mas_tasks/new_queue.rs

// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use cron::Schedule;
use mas_context::LogContext;
use mas_storage::{
    Clock, RepositoryAccess, RepositoryError,
    queue::{InsertableJob, Job, JobMetadata, Worker},
};
use mas_storage_pg::{DatabaseError, PgRepository};
use opentelemetry::{
    KeyValue,
    metrics::{Counter, Histogram, UpDownCounter},
};
use rand::{Rng, RngCore, distributions::Uniform};
use serde::de::DeserializeOwned;
use sqlx::{
    Acquire, Either,
    postgres::{PgAdvisoryLock, PgListener},
};
use thiserror::Error;
use tokio::{task::JoinSet, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;

use crate::{METER, State};

type JobPayload = serde_json::Value;

#[derive(Clone)]
pub struct JobContext {
    pub id: Ulid,
    pub metadata: JobMetadata,
    pub queue_name: String,
    pub attempt: usize,
    pub start: Instant,

    #[expect(
        dead_code,
        reason = "we're not yet using this, but will be in the future"
    )]
    pub cancellation_token: CancellationToken,
}

impl JobContext {
    pub fn span(&self) -> Span {
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue.name = self.queue_name,
            job.attempt = self.attempt,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JobErrorDecision {
    Retry,

    #[default]
    Fail,
}

impl std::fmt::Display for JobErrorDecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Retry => f.write_str("retry"),
            Self::Fail => f.write_str("fail"),
        }
    }
}

#[derive(Debug, Error)]
#[error("Job failed to run, will {decision}")]
pub struct JobError {
    decision: JobErrorDecision,
    #[source]
    error: anyhow::Error,
}

impl JobError {
    pub fn retry<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Retry,
            error: error.into(),
        }
    }

    pub fn fail<T: Into<anyhow::Error>>(error: T) -> Self {
        Self {
            decision: JobErrorDecision::Fail,
            error: error.into(),
        }
    }
}
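
// Illustrative sketch (not part of this module): inside a job's `run` method,
// transient failures would typically be wrapped with `JobError::retry` and
// permanent ones with `JobError::fail`. The `fetch_remote_data` helper below is
// hypothetical.
//
//     async fn run_step() -> Result<(), JobError> {
//         let data = fetch_remote_data()
//             .await
//             .map_err(JobError::retry)?; // network errors are worth retrying
//         if data.is_empty() {
//             // nothing to do, and retrying would not help
//             return Err(JobError::fail(anyhow::anyhow!("nothing to process")));
//         }
//         Ok(())
//     }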

pub trait FromJob {
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}

impl<T> FromJob for T
where
    T: DeserializeOwned,
{
    fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
        serde_json::from_value(payload).map_err(Into::into)
    }
}

#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
    async fn run(&self, state: &State, context: JobContext) -> Result<(), JobError>;
}

fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
    Box::new(job)
}

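// Illustrative sketch (not part of this module): thanks to the blanket
// `FromJob` impl above, any `DeserializeOwned` payload type can be run as a job
// once it implements `RunnableJob`. The `CleanupExpiredTokensJob` name is made
// up for the example.
//
//     #[derive(serde::Serialize, serde::Deserialize)]
//     struct CleanupExpiredTokensJob;
//
//     #[async_trait]
//     impl RunnableJob for CleanupExpiredTokensJob {
//         async fn run(&self, _state: &State, _context: JobContext) -> Result<(), JobError> {
//             // Do the actual work here, mapping errors to JobError::retry/fail
//             Ok(())
//         }
//     }
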
#[derive(Debug, Error)]
pub enum QueueRunnerError {
    #[error("Failed to setup listener")]
    SetupListener(#[source] sqlx::Error),

    #[error("Failed to start transaction")]
    StartTransaction(#[source] sqlx::Error),

    #[error("Failed to commit transaction")]
    CommitTransaction(#[source] sqlx::Error),

    #[error("Failed to acquire leader lock")]
    LeaderLock(#[source] sqlx::Error),

    #[error(transparent)]
    Repository(#[from] RepositoryError),

    #[error(transparent)]
    Database(#[from] DatabaseError),

    #[error("Invalid schedule expression")]
    InvalidSchedule(#[from] cron::error::Error),

    #[error("Worker is not the leader")]
    NotLeader,
}

// When the worker waits for a notification, we still want to wake it up every
// second. Because we don't want all the workers to wake up at the same time, we
// add a random jitter to the sleep duration, so they effectively sleep between
// 0.9 and 1.1 seconds.
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);

// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;

// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;

// How many times a job should be retried before being abandoned
const MAX_ATTEMPTS: usize = 10;

/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m40s,
/// 21m20s, 42m40s
fn retry_delay(attempt: usize) -> Duration {
    let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
    Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}
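
// Worked example of the backoff above: retry_delay(3) is 2^3 * 5_000 ms = 40s,
// and retry_delay(7) is 2^7 * 5_000 ms = 640s = 10m40s. A minimal sanity-check
// sketch (not part of this module):
//
//     assert_eq!(retry_delay(0), Duration::milliseconds(5_000));
//     assert_eq!(retry_delay(3), Duration::seconds(40));
//     assert_eq!(retry_delay(7), Duration::seconds(640));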

type JobResult = (std::time::Duration, Result<(), JobError>);
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;

struct ScheduleDefinition {
    schedule_name: &'static str,
    expression: Schedule,
    queue_name: &'static str,
    payload: serde_json::Value,
}
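
// Illustrative sketch (not part of this module): the `expression` field is a
// `cron::Schedule`, which is usually parsed from a seconds-first cron string
// and can be queried for the next tick after a given time. The expression below
// (every day at 03:00 UTC) is only an example.
//
//     use std::str::FromStr;
//     let expression = Schedule::from_str("0 0 3 * * *").expect("valid cron expression");
//     let next_tick = expression.after(&Utc::now()).next();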

pub struct QueueWorker {
    listener: PgListener,
    registration: Worker,
    am_i_leader: bool,
    last_heartbeat: DateTime<Utc>,
    cancellation_token: CancellationToken,
    #[expect(dead_code, reason = "This is used on Drop")]
    cancellation_guard: tokio_util::sync::DropGuard,
    state: State,
    schedules: Vec<ScheduleDefinition>,
    tracker: JobTracker,
    wakeup_reason: Counter<u64>,
    tick_time: Histogram<u64>,
}

impl QueueWorker {
    #[tracing::instrument(
        name = "worker.init",
        skip_all,
        fields(worker.id)
    )]
    pub(crate) async fn new(
        state: State,
        cancellation_token: CancellationToken,
    ) -> Result<Self, QueueRunnerError> {
        let mut rng = state.rng();
        let clock = state.clock();

        let mut listener = PgListener::connect_with(&state.pool())
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications of leader stepping down on this channel
        listener
            .listen("queue_leader_stepdown")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        // We get notifications when a job is available on this channel
        listener
            .listen("queue_available")
            .await
            .map_err(QueueRunnerError::SetupListener)?;

        let txn = listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        let registration = repo.queue_worker().register(&mut rng, clock).await?;
        tracing::Span::current().record("worker.id", tracing::field::display(registration.id));
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        tracing::info!(worker.id = %registration.id, "Registered worker");
        let now = clock.now();

        let wakeup_reason = METER
            .u64_counter("job.worker.wakeups")
            .with_description("Counts how many times the worker has been woken up, and for which reason")
            .build();

        // Pre-create the reasons on the counter
        wakeup_reason.add(0, &[KeyValue::new("reason", "sleep")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "task")]);
        wakeup_reason.add(0, &[KeyValue::new("reason", "notification")]);

        let tick_time = METER
            .u64_histogram("job.worker.tick_duration")
            .with_description(
                "How much time the worker took to tick, including performing leader duties",
            )
            .build();

        // We put a cancellation drop guard in the structure, so that when it gets
        // dropped, we're sure to cancel the token
        let cancellation_guard = cancellation_token.clone().drop_guard();

        Ok(Self {
            listener,
            registration,
            am_i_leader: false,
            last_heartbeat: now,
            cancellation_token,
            cancellation_guard,
            state,
            schedules: Vec::new(),
            tracker: JobTracker::new(),
            wakeup_reason,
            tick_time,
        })
    }

    pub(crate) fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
        // There is a potential panic here, which is fine as it's going to be caught
        // within the job task
        let factory = |payload: JobPayload| {
            box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
        };

        self.tracker
            .factories
            .insert(T::QUEUE_NAME, Arc::new(factory));
        self
    }

    pub(crate) fn add_schedule<T: InsertableJob>(
        &mut self,
        schedule_name: &'static str,
        expression: Schedule,
        job: T,
    ) -> &mut Self {
        let payload = serde_json::to_value(job).expect("failed to serialize job payload");

        self.schedules.push(ScheduleDefinition {
            schedule_name,
            expression,
            queue_name: T::QUEUE_NAME,
            payload,
        });

        self
    }
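
    // Illustrative sketch (not part of this module): a worker is typically set
    // up by registering the job types it can run and any recurring schedules.
    // The job types and cron expression below are made up for the example.
    //
    //     worker
    //         .register_handler::<CleanupExpiredTokensJob>()
    //         .register_handler::<SendEmailJob>()
    //         .add_schedule(
    //             "cleanup-expired-tokens",
    //             "0 0 3 * * *".parse().expect("valid cron expression"),
    //             CleanupExpiredTokensJob,
    //         );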

    pub(crate) async fn run(mut self) {
        if let Err(e) = self.run_inner().await {
            tracing::error!(
                error = &e as &dyn std::error::Error,
                "Failed to run new queue"
            );
        }
    }

    async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
        self.setup_schedules().await?;

        while !self.cancellation_token.is_cancelled() {
            LogContext::new("worker-run-loop")
                .run(|| self.run_loop())
                .await?;
        }

        self.shutdown().await?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.setup_schedules", skip_all)]
    pub(crate) async fn setup_schedules(&mut self) -> Result<(), QueueRunnerError> {
        let schedules: Vec<_> = self.schedules.iter().map(|s| s.schedule_name).collect();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Setup the entries in the queue_schedules table
        repo.queue_schedule().setup(&schedules).await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.run_loop", skip_all)]
    async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
        self.wait_until_wakeup().await?;

        if self.cancellation_token.is_cancelled() {
            return Ok(());
        }

        let start = Instant::now();
        self.tick().await?;

        if self.am_i_leader {
            self.perform_leader_duties().await?;
        }

        let elapsed = start.elapsed().as_millis().try_into().unwrap_or(u64::MAX);
        self.tick_time.record(elapsed, &[]);

        Ok(())
    }

    #[tracing::instrument(name = "worker.shutdown", skip_all)]
    async fn shutdown(&mut self) -> Result<(), QueueRunnerError> {
        tracing::info!("Shutting down worker");

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        let mut repo = PgRepository::from_conn(txn);

        // Log about any job still running
        match self.tracker.running_jobs() {
            0 => {}
            1 => tracing::warn!("There is one job still running, waiting for it to finish"),
            n => tracing::warn!("There are {n} jobs still running, waiting for them to finish"),
        }

        // TODO: we may want to introduce a timeout here, and abort the tasks if they
        // take too long. It's fine for now, as we don't have long-running
        // tasks, most of them are idempotent, and the only effect might be that
        // the worker would 'dirtily' shutdown, meaning that its tasks would be
        // considered lost and later retried by another worker

        // Wait for all the jobs to finish
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        // Tell the other workers we're shutting down
        // This also releases the leader election lease
        repo.queue_worker()
            .shutdown(clock, &self.registration)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    #[tracing::instrument(name = "worker.wait_until_wakeup", skip_all)]
    async fn wait_until_wakeup(&mut self) -> Result<(), QueueRunnerError> {
        let mut rng = self.state.rng();

        // This is to make sure we wake up every second to do the maintenance tasks
        // We add a little bit of random jitter to the duration, so that we don't get
        // fully synced workers waking up at the same time after each notification
        let sleep_duration = rng.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
        let wakeup_sleep = tokio::time::sleep(sleep_duration);

        tokio::select! {
            () = self.cancellation_token.cancelled() => {
                tracing::debug!("Woke up from cancellation");
            },

            () = wakeup_sleep => {
                tracing::debug!("Woke up from sleep");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "sleep")]);
            },

            () = self.tracker.collect_next_job(), if self.tracker.has_jobs() => {
                tracing::debug!("Joined job task");
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "task")]);
            },

            notification = self.listener.recv() => {
                self.wakeup_reason.add(1, &[KeyValue::new("reason", "notification")]);
                match notification {
                    Ok(notification) => {
                        tracing::debug!(
                            notification.channel = notification.channel(),
                            notification.payload = notification.payload(),
                            "Woke up from notification"
                        );
                    },
                    Err(e) => {
                        tracing::error!(error = &e as &dyn std::error::Error, "Failed to receive notification");
                    },
                }
            },
        }

        Ok(())
    }

    #[tracing::instrument(
        name = "worker.tick",
        skip_all,
        fields(worker.id = %self.registration.id),
    )]
    async fn tick(&mut self) -> Result<(), QueueRunnerError> {
        tracing::debug!("Tick");
        let clock = self.state.clock();
        let mut rng = self.state.rng();
        let now = clock.now();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // We send a heartbeat every minute, to avoid writing to the database too often
        // on a logged table
        if now - self.last_heartbeat >= chrono::Duration::minutes(1) {
            tracing::info!("Sending heartbeat");
            repo.queue_worker()
                .heartbeat(clock, &self.registration)
                .await?;
            self.last_heartbeat = now;
        }

        // Remove any dead worker leader leases
        repo.queue_worker()
            .remove_leader_lease_if_expired(clock)
            .await?;

        // Try to become (or stay) the leader
        let leader = repo
            .queue_worker()
            .try_get_leader_lease(clock, &self.registration)
            .await?;

        // Process any job task which finished
        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, false)
            .await?;

        // Compute how many jobs we should fetch at most: the remaining capacity,
        // capped at MAX_JOBS_TO_FETCH
        let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
            .saturating_sub(self.tracker.running_jobs())
            .min(MAX_JOBS_TO_FETCH);

        if max_jobs_to_fetch == 0 {
            tracing::warn!("Internal job queue is full, not fetching any new jobs");
        } else {
            // Grab a few jobs in the queue
            let queues = self.tracker.queues();
            let jobs = repo
                .queue_job()
                .reserve(clock, &self.registration, &queues, max_jobs_to_fetch)
                .await?;

            for Job {
                id,
                queue_name,
                payload,
                metadata,
                attempt,
            } in jobs
            {
                let cancellation_token = self.cancellation_token.child_token();
                let start = Instant::now();
                let context = JobContext {
                    id,
                    metadata,
                    queue_name,
                    attempt,
                    start,
                    cancellation_token,
                };

                self.tracker.spawn_job(self.state.clone(), context, payload);
            }
        }

        // After this point, we are locking the leader table, so it's important that we
        // commit as soon as possible to not block the other workers for too long
        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        // Save the new leader state to log any change
        if leader != self.am_i_leader {
            // If we flipped state, log it
            self.am_i_leader = leader;
            if self.am_i_leader {
                tracing::info!("I'm the leader now");
            } else {
                tracing::warn!("I am no longer the leader");
            }
        }

        Ok(())
    }

    #[tracing::instrument(name = "worker.perform_leader_duties", skip_all)]
    async fn perform_leader_duties(&mut self) -> Result<(), QueueRunnerError> {
        // This should have been checked by the caller, but better safe than sorry
        if !self.am_i_leader {
            return Err(QueueRunnerError::NotLeader);
        }

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Start a transaction on the existing PgListener connection
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;

        // The thing with the leader election is that it locks the table during the
        // election, preventing other workers from going through the loop.
        //
        // Ideally, we would do the leader duties in the same transaction so that we
        // make sure only one worker is doing the leader duties, but that
        // would mean we would lock all the workers for the duration of the
        // duties, which is not ideal.
        //
        // So we do the duties in a separate transaction, in which we take an advisory
        // lock, so that in the very rare case where two workers think they are the
        // leader, we still don't have two workers doing the duties at the same time.
        let lock = PgAdvisoryLock::new("leader-duties");

        let locked = lock
            .try_acquire(txn)
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        let locked = match locked {
            Either::Left(locked) => locked,
            Either::Right(txn) => {
                tracing::error!("Another worker has the leader lock, aborting");
                txn.rollback()
                    .await
                    .map_err(QueueRunnerError::CommitTransaction)?;
                return Ok(());
            }
        };

        let mut repo = PgRepository::from_conn(locked);

        // Look at the state of schedules in the database
        let schedules_status = repo.queue_schedule().list().await?;

        let now = clock.now();
        for schedule in &self.schedules {
            // Find the schedule status from the database
            let Some(schedule_status) = schedules_status
                .iter()
                .find(|s| s.schedule_name == schedule.schedule_name)
            else {
                tracing::error!(
                    "Schedule {} was not found in the database",
                    schedule.schedule_name
                );
                continue;
            };

            // Figure out if we should schedule a new job
            if let Some(next_time) = schedule_status.last_scheduled_at {
                if next_time > now {
                    // We already have a job scheduled in the future, skip
                    continue;
                }

                if schedule_status.last_scheduled_job_completed == Some(false) {
                    // The last scheduled job has not completed yet, skip
                    continue;
                }
            }

            let next_tick = schedule.expression.after(&now).next().unwrap();

            tracing::info!(
                "Scheduling job for {}, next run at {}",
                schedule.schedule_name,
                next_tick
            );

            repo.queue_job()
                .schedule_later(
                    &mut rng,
                    clock,
                    schedule.queue_name,
                    schedule.payload.clone(),
                    serde_json::json!({}),
                    next_tick,
                    Some(schedule.schedule_name),
                )
                .await?;
        }

        // We also check for dead workers, and shut down any that haven't checked
        // in within the last two minutes
        repo.queue_worker()
            .shutdown_dead_workers(clock, Duration::minutes(2))
            .await?;

        // TODO: mark tasks those workers had as lost

        // Mark all the scheduled jobs as available
        let scheduled = repo.queue_job().schedule_available_jobs(clock).await?;
        match scheduled {
            0 => {}
            1 => tracing::info!("One scheduled job marked as available"),
            n => tracing::info!("{n} scheduled jobs marked as available"),
        }

        // Release the leader lock
        let txn = repo
            .into_inner()
            .release_now()
            .await
            .map_err(QueueRunnerError::LeaderLock)?;

        txn.commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }

    /// Process all the pending jobs in the queue.
    /// This should only be called in tests!
    ///
    /// # Errors
    ///
    /// This function can fail if the database connection fails.
    pub async fn process_all_jobs_in_tests(&mut self) -> Result<(), QueueRunnerError> {
        // I swear, I'm the leader!
        self.am_i_leader = true;

        // First, perform the leader duties. This will make sure that we schedule
        // recurring jobs.
        self.perform_leader_duties().await?;

        let clock = self.state.clock();
        let mut rng = self.state.rng();

        // Grab the connection from the PgListener
        let txn = self
            .listener
            .begin()
            .await
            .map_err(QueueRunnerError::StartTransaction)?;
        let mut repo = PgRepository::from_conn(txn);

        // Spawn all the jobs in the database
        let queues = self.tracker.queues();
        let jobs = repo
            .queue_job()
            // I really hope that we don't spawn more than 10k jobs in tests
            .reserve(clock, &self.registration, &queues, 10_000)
            .await?;

        for Job {
            id,
            queue_name,
            payload,
            metadata,
            attempt,
        } in jobs
        {
            let cancellation_token = self.cancellation_token.child_token();
            let start = Instant::now();
            let context = JobContext {
                id,
                metadata,
                queue_name,
                attempt,
                start,
                cancellation_token,
            };

            self.tracker.spawn_job(self.state.clone(), context, payload);
        }

        self.tracker
            .process_jobs(&mut rng, clock, &mut repo, true)
            .await?;

        repo.into_inner()
            .commit()
            .await
            .map_err(QueueRunnerError::CommitTransaction)?;

        Ok(())
    }
}

/// Tracks running jobs
///
/// This is a separate structure to be able to borrow it mutably at the same
/// time as the connection to the database is borrowed
struct JobTracker {
    /// Stores a mapping from the job queue name to the job factory
    factories: HashMap<&'static str, JobFactory>,

    /// A join set of all the currently running jobs
    running_jobs: JoinSet<JobResult>,

    /// Stores a mapping from the Tokio task ID to the job context
    job_contexts: HashMap<tokio::task::Id, JobContext>,

    /// Stores the last `join_next_with_id` result for processing, in case we
    /// got woken up in `collect_next_job`
    last_join_result: Option<Result<(tokio::task::Id, JobResult), tokio::task::JoinError>>,

    /// A histogram which records the time it takes to process a job
    job_processing_time: Histogram<u64>,

    /// A counter which records the number of jobs currently in flight
    in_flight_jobs: UpDownCounter<i64>,
}

impl JobTracker {
    fn new() -> Self {
        let job_processing_time = METER
            .u64_histogram("job.process.duration")
            .with_description("The time it takes to process a job in milliseconds")
            .with_unit("ms")
            .build();

        let in_flight_jobs = METER
            .i64_up_down_counter("job.active_tasks")
            .with_description("The number of jobs currently in flight")
            .with_unit("{job}")
            .build();

        Self {
            factories: HashMap::new(),
            running_jobs: JoinSet::new(),
            job_contexts: HashMap::new(),
            last_join_result: None,
            job_processing_time,
            in_flight_jobs,
        }
    }

    /// Returns the queue names that are currently being tracked
    fn queues(&self) -> Vec<&'static str> {
        self.factories.keys().copied().collect()
    }

    /// Spawn a job on the job tracker
    fn spawn_job(&mut self, state: State, context: JobContext, payload: JobPayload) {
        let factory = self.factories.get(context.queue_name.as_str()).cloned();
        let task = {
            let log_context = LogContext::new(format!("job-{}", context.queue_name));
            let context = context.clone();
            let span = context.span();
            log_context
                .run(async move || {
                    // We should never crash, but in case we do, we do that in the task and
                    // don't crash the worker
                    let job = factory.expect("unknown job factory")(payload);
                    tracing::info!(
                        job.id = %context.id,
                        job.queue.name = %context.queue_name,
                        job.attempt = %context.attempt,
                        "Running job"
                    );
                    let result = job.run(&state, context.clone()).await;

                    let Some(context_stats) =
                        LogContext::maybe_with(mas_context::LogContext::stats)
                    else {
                        // This should never happen, but if it does it's fine: we're recovering fine
                        // from panics in those tasks
                        panic!("Missing log context, this should never happen");
                    };

                    // We log the result here so that it's attached to the right span & log context
                    match &result {
                        Ok(()) => {
                            tracing::info!(
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job completed [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Fail,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, not retrying [{context_stats}]"
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) if context.attempt < MAX_ATTEMPTS => {
                            let delay = retry_delay(context.attempt);
                            tracing::warn!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed, will retry in {}s [{context_stats}]",
                                delay.num_seconds()
                            );
                        }

                        Err(JobError {
                            decision: JobErrorDecision::Retry,
                            error,
                        }) => {
                            tracing::error!(
                                error = &**error as &dyn std::error::Error,
                                job.id = %context.id,
                                job.queue.name = %context.queue_name,
                                job.attempt = %context.attempt,
                                "Job failed too many times, abandoning [{context_stats}]"
                            );
                        }
                    }

                    (context_stats.elapsed, result)
                })
                .instrument(span)
        };

        self.in_flight_jobs.add(
            1,
            &[KeyValue::new("job.queue.name", context.queue_name.clone())],
        );

        let handle = self.running_jobs.spawn(task);
        self.job_contexts.insert(handle.id(), context);
    }

    /// Returns `true` if there are currently running jobs
    fn has_jobs(&self) -> bool {
        !self.running_jobs.is_empty()
    }

    /// Returns the number of currently running jobs
    ///
    /// This also includes the job result which may be stored for processing
    fn running_jobs(&self) -> usize {
        self.running_jobs.len() + usize::from(self.last_join_result.is_some())
    }

    async fn collect_next_job(&mut self) {
        // Double-check that we don't have a job result stored
        if self.last_join_result.is_some() {
            tracing::error!(
                "Job tracker already had a job result stored, this should never happen!"
            );
            return;
        }

        self.last_join_result = self.running_jobs.join_next_with_id().await;
    }

    /// Process all the jobs which are currently running
    ///
    /// If `blocking` is `true`, this function will block until all the jobs
    /// are finished. Otherwise, it will return as soon as it processed the
    /// already finished jobs.
    #[allow(clippy::too_many_lines)]
    async fn process_jobs<E: std::error::Error + Send + Sync + 'static>(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        repo: &mut dyn RepositoryAccess<Error = E>,
        blocking: bool,
    ) -> Result<(), E> {
        if self.last_join_result.is_none() {
            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        while let Some(result) = self.last_join_result.take() {
            match result {
                // The job succeeded. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Ok(())))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    self.job_processing_time.record(
                        elapsed_ms,
                        &[
                            KeyValue::new("job.queue.name", context.queue_name),
                            KeyValue::new("job.result", "success"),
                        ],
                    );

                    repo.queue_job()
                        .mark_as_completed(clock, context.id)
                        .await?;
                }

                // The job failed. The logging and time measurement is already done in the task
                Ok((id, (elapsed, Err(e)))) => {
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    let reason = format!("{:?}", e.error);
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    let elapsed_ms = elapsed.as_millis().try_into().unwrap_or(u64::MAX);
                    match e.decision {
                        JobErrorDecision::Fail => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "fail"),
                                ],
                            );
                        }

                        JobErrorDecision::Retry if context.attempt < MAX_ATTEMPTS => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "retry"),
                                ],
                            );

                            let delay = retry_delay(context.attempt);
                            repo.queue_job()
                                .retry(&mut *rng, clock, context.id, delay)
                                .await?;
                        }

                        JobErrorDecision::Retry => {
                            self.job_processing_time.record(
                                elapsed_ms,
                                &[
                                    KeyValue::new("job.queue.name", context.queue_name),
                                    KeyValue::new("job.result", "failed"),
                                    KeyValue::new("job.decision", "abandon"),
                                ],
                            );
                        }
                    }
                }

                // The job crashed (or was aborted)
                Err(e) => {
                    let id = e.id();
                    let context = self
                        .job_contexts
                        .remove(&id)
                        .expect("Job context not found");

                    self.in_flight_jobs.add(
                        -1,
                        &[KeyValue::new("job.queue.name", context.queue_name.clone())],
                    );

                    // This measurement is not accurate as it includes the time processing the jobs,
                    // but it's fine, it's only for panicked tasks
                    let elapsed = context
                        .start
                        .elapsed()
                        .as_millis()
                        .try_into()
                        .unwrap_or(u64::MAX);

                    let reason = e.to_string();
                    repo.queue_job()
                        .mark_as_failed(clock, context.id, &reason)
                        .await?;

                    if context.attempt < MAX_ATTEMPTS {
                        let delay = retry_delay(context.attempt);
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed, will retry in {}s",
                            delay.num_seconds()
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "retry"),
                            ],
                        );

                        repo.queue_job()
                            .retry(&mut *rng, clock, context.id, delay)
                            .await?;
                    } else {
                        tracing::error!(
                            error = &e as &dyn std::error::Error,
                            job.id = %context.id,
                            job.queue.name = %context.queue_name,
                            job.attempt = %context.attempt,
                            job.elapsed = format!("{elapsed}ms"),
                            "Job crashed too many times, abandoning"
                        );

                        self.job_processing_time.record(
                            elapsed,
                            &[
                                KeyValue::new("job.queue.name", context.queue_name),
                                KeyValue::new("job.result", "crashed"),
                                KeyValue::new("job.decision", "abandon"),
                            ],
                        );
                    }
                }
            }

            if blocking {
                self.last_join_result = self.running_jobs.join_next_with_id().await;
            } else {
                self.last_join_result = self.running_jobs.try_join_next_with_id();
            }
        }

        Ok(())
    }
}