mas_storage_pg/oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13    Clock, Page, Pagination,
14    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{Expr, PgFunc, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError, DatabaseInconsistencyError,
26    filter::{Filter, StatementExt},
27    iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
33pub struct PgOAuth2SessionRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgOAuth2SessionRepository<'c> {
38    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
/// Raw row shape of the `oauth2_sessions` table, as read by `lookup` and `list`.
///
/// `#[enum_def]` generates `OAuthSessionLookupIden` from these field names;
/// `list` uses those identifiers as column aliases so that `sqlx::FromRow` can
/// decode the dynamically built query into this struct. Renaming a field
/// therefore changes both the generated identifier and the expected column.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    // Primary key; converted to a `Ulid` when building a `Session`.
    oauth2_session_id: Uuid,
    // Owning user, if any (nullable column).
    user_id: Option<Uuid>,
    // Associated browser session (`user_sessions`), if any (nullable column).
    user_session_id: Option<Uuid>,
    // Client the session was issued to.
    oauth2_client_id: Uuid,
    // Raw scope tokens; each one is parsed into a `ScopeToken` on conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is still valid; set when it is finished.
    finished_at: Option<DateTime<Utc>>,
    // Last recorded user agent (written by `record_user_agent`).
    user_agent: Option<String>,
    // Activity tracking (written by `record_batch_activity`).
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Optional display name (written by `set_human_name`).
    human_name: Option<String>,
}
60
61impl TryFrom<OAuthSessionLookup> for Session {
62    type Error = DatabaseInconsistencyError;
63
64    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
65        let id = Ulid::from(value.oauth2_session_id);
66        let scope: Result<Scope, _> = value
67            .scope_list
68            .iter()
69            .map(|s| s.parse::<ScopeToken>())
70            .collect();
71        let scope = scope.map_err(|e| {
72            DatabaseInconsistencyError::on("oauth2_sessions")
73                .column("scope")
74                .row(id)
75                .source(e)
76        })?;
77
78        let state = match value.finished_at {
79            None => SessionState::Valid,
80            Some(finished_at) => SessionState::Finished { finished_at },
81        };
82
83        Ok(Session {
84            id,
85            state,
86            created_at: value.created_at,
87            client_id: value.oauth2_client_id.into(),
88            user_id: value.user_id.map(Ulid::from),
89            user_session_id: value.user_session_id.map(Ulid::from),
90            scope,
91            user_agent: value.user_agent,
92            last_active_at: value.last_active_at,
93            last_active_ip: value.last_active_ip,
94            human_name: value.human_name,
95        })
96    }
97}
98
impl Filter for OAuth2SessionFilter<'_> {
    /// Builds the SQL condition for this filter.
    ///
    /// The result is a conjunction (`Condition::all`) with one clause per
    /// criterion; criteria left unset on the filter yield `None` and are
    /// passed to `add_option`, which skips them. Clause order here determines
    /// the order of the `AND`ed terms in the generated SQL.
    ///
    /// `_has_joins` is ignored: every column is already qualified with its
    /// table name.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        sea_query::Condition::all()
            .add_option(self.user().map(|user| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
            }))
            .add_option(self.client().map(|client| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                    .eq(Uuid::from(client.id))
            }))
            .add_option(self.client_kind().map(|client_kind| {
                // This builds either a:
                // `WHERE oauth2_client_id = ANY(...)`
                // or a `WHERE oauth2_client_id <> ALL(...)`
                // where `...` selects the IDs of clients flagged `is_static`.
                // NOTE(review): `<> ALL` relies on the subquery yielding no NULL
                // IDs; presumably `oauth2_client_id` is non-nullable — confirm.
                let static_clients = Query::select()
                    .expr(Expr::col((
                        OAuth2Clients::Table,
                        OAuth2Clients::OAuth2ClientId,
                    )))
                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
                    .from(OAuth2Clients::Table)
                    .take();
                if client_kind.is_static() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .eq(Expr::any(static_clients))
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .ne(Expr::all(static_clients))
                }
            }))
            .add_option(self.device().map(|device| {
                // Devices are matched through their scope token:
                // `'<token>' = ANY(scope_list)`.
                if let Ok(scope_token) = device.to_scope_token() {
                    Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
                        OAuth2Sessions::Table,
                        OAuth2Sessions::ScopeList,
                    ))))
                } else {
                    // If the device ID can't be encoded as a scope token, match no rows
                    Expr::val(false).into()
                }
            }))
            .add_option(self.browser_session().map(|browser_session| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
                    .eq(Uuid::from(browser_session.id))
            }))
            .add_option(self.browser_session_filter().map(|browser_session_filter| {
                // `user_session_id IN (SELECT user_session_id FROM user_sessions
                // WHERE <nested browser-session filter>)`
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
                    Query::select()
                        .expr(Expr::col((
                            UserSessions::Table,
                            UserSessions::UserSessionId,
                        )))
                        .apply_filter(browser_session_filter)
                        .from(UserSessions::Table)
                        .take(),
                )
            }))
            .add_option(self.state().map(|state| {
                // Active sessions are exactly those without a `finished_at`.
                if state.is_active() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
                }
            }))
            .add_option(self.scope().map(|scope| {
                // Session's `scope_list` must contain every requested token
                // (PostgreSQL array containment via `PgExpr::contains`).
                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
            }))
            .add_option(self.any_user().map(|any_user| {
                // `true` keeps sessions owned by some user, `false` keeps
                // sessions with no user.
                if any_user {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
                }
            }))
            // Both activity bounds are strict; sessions whose `last_active_at`
            // is NULL never satisfy either comparison.
            .add_option(self.last_active_after().map(|last_active_after| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .gt(last_active_after)
            }))
            .add_option(self.last_active_before().map(|last_active_before| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .lt(last_active_before)
            }))
    }
}
184
185#[async_trait]
186impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
187    type Error = DatabaseError;
188
189    #[tracing::instrument(
190        name = "db.oauth2_session.lookup",
191        skip_all,
192        fields(
193            db.query.text,
194            session.id = %id,
195        ),
196        err,
197    )]
198    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
199        let res = sqlx::query_as!(
200            OAuthSessionLookup,
201            r#"
202                SELECT oauth2_session_id
203                     , user_id
204                     , user_session_id
205                     , oauth2_client_id
206                     , scope_list
207                     , created_at
208                     , finished_at
209                     , user_agent
210                     , last_active_at
211                     , last_active_ip as "last_active_ip: IpAddr"
212                     , human_name
213                FROM oauth2_sessions
214
215                WHERE oauth2_session_id = $1
216            "#,
217            Uuid::from(id),
218        )
219        .traced()
220        .fetch_optional(&mut *self.conn)
221        .await?;
222
223        let Some(session) = res else { return Ok(None) };
224
225        Ok(Some(session.try_into()?))
226    }
227
228    #[tracing::instrument(
229        name = "db.oauth2_session.add",
230        skip_all,
231        fields(
232            db.query.text,
233            %client.id,
234            session.id,
235            session.scope = %scope,
236        ),
237        err,
238    )]
239    async fn add(
240        &mut self,
241        rng: &mut (dyn RngCore + Send),
242        clock: &dyn Clock,
243        client: &Client,
244        user: Option<&User>,
245        user_session: Option<&BrowserSession>,
246        scope: Scope,
247    ) -> Result<Session, Self::Error> {
248        let created_at = clock.now();
249        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
250        tracing::Span::current().record("session.id", tracing::field::display(id));
251
252        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
253
254        sqlx::query!(
255            r#"
256                INSERT INTO oauth2_sessions
257                    ( oauth2_session_id
258                    , user_id
259                    , user_session_id
260                    , oauth2_client_id
261                    , scope_list
262                    , created_at
263                    )
264                VALUES ($1, $2, $3, $4, $5, $6)
265            "#,
266            Uuid::from(id),
267            user.map(|u| Uuid::from(u.id)),
268            user_session.map(|s| Uuid::from(s.id)),
269            Uuid::from(client.id),
270            &scope_list,
271            created_at,
272        )
273        .traced()
274        .execute(&mut *self.conn)
275        .await?;
276
277        Ok(Session {
278            id,
279            state: SessionState::Valid,
280            created_at,
281            user_id: user.map(|u| u.id),
282            user_session_id: user_session.map(|s| s.id),
283            client_id: client.id,
284            scope,
285            user_agent: None,
286            last_active_at: None,
287            last_active_ip: None,
288            human_name: None,
289        })
290    }
291
292    #[tracing::instrument(
293        name = "db.oauth2_session.finish_bulk",
294        skip_all,
295        fields(
296            db.query.text,
297        ),
298        err,
299    )]
300    async fn finish_bulk(
301        &mut self,
302        clock: &dyn Clock,
303        filter: OAuth2SessionFilter<'_>,
304    ) -> Result<usize, Self::Error> {
305        let finished_at = clock.now();
306        let (sql, arguments) = Query::update()
307            .table(OAuth2Sessions::Table)
308            .value(OAuth2Sessions::FinishedAt, finished_at)
309            .apply_filter(filter)
310            .build_sqlx(PostgresQueryBuilder);
311
312        let res = sqlx::query_with(&sql, arguments)
313            .traced()
314            .execute(&mut *self.conn)
315            .await?;
316
317        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
318    }
319
320    #[tracing::instrument(
321        name = "db.oauth2_session.finish",
322        skip_all,
323        fields(
324            db.query.text,
325            %session.id,
326            %session.scope,
327            client.id = %session.client_id,
328        ),
329        err,
330    )]
331    async fn finish(
332        &mut self,
333        clock: &dyn Clock,
334        session: Session,
335    ) -> Result<Session, Self::Error> {
336        let finished_at = clock.now();
337        let res = sqlx::query!(
338            r#"
339                UPDATE oauth2_sessions
340                SET finished_at = $2
341                WHERE oauth2_session_id = $1
342            "#,
343            Uuid::from(session.id),
344            finished_at,
345        )
346        .traced()
347        .execute(&mut *self.conn)
348        .await?;
349
350        DatabaseError::ensure_affected_rows(&res, 1)?;
351
352        session
353            .finish(finished_at)
354            .map_err(DatabaseError::to_invalid_operation)
355    }
356
357    #[tracing::instrument(
358        name = "db.oauth2_session.list",
359        skip_all,
360        fields(
361            db.query.text,
362        ),
363        err,
364    )]
365    async fn list(
366        &mut self,
367        filter: OAuth2SessionFilter<'_>,
368        pagination: Pagination,
369    ) -> Result<Page<Session>, Self::Error> {
370        let (sql, arguments) = Query::select()
371            .expr_as(
372                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
373                OAuthSessionLookupIden::Oauth2SessionId,
374            )
375            .expr_as(
376                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
377                OAuthSessionLookupIden::UserId,
378            )
379            .expr_as(
380                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
381                OAuthSessionLookupIden::UserSessionId,
382            )
383            .expr_as(
384                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
385                OAuthSessionLookupIden::Oauth2ClientId,
386            )
387            .expr_as(
388                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
389                OAuthSessionLookupIden::ScopeList,
390            )
391            .expr_as(
392                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
393                OAuthSessionLookupIden::CreatedAt,
394            )
395            .expr_as(
396                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
397                OAuthSessionLookupIden::FinishedAt,
398            )
399            .expr_as(
400                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
401                OAuthSessionLookupIden::UserAgent,
402            )
403            .expr_as(
404                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
405                OAuthSessionLookupIden::LastActiveAt,
406            )
407            .expr_as(
408                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
409                OAuthSessionLookupIden::LastActiveIp,
410            )
411            .expr_as(
412                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
413                OAuthSessionLookupIden::HumanName,
414            )
415            .from(OAuth2Sessions::Table)
416            .apply_filter(filter)
417            .generate_pagination(
418                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
419                pagination,
420            )
421            .build_sqlx(PostgresQueryBuilder);
422
423        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
424            .traced()
425            .fetch_all(&mut *self.conn)
426            .await?;
427
428        let page = pagination.process(edges).try_map(Session::try_from)?;
429
430        Ok(page)
431    }
432
433    #[tracing::instrument(
434        name = "db.oauth2_session.count",
435        skip_all,
436        fields(
437            db.query.text,
438        ),
439        err,
440    )]
441    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
442        let (sql, arguments) = Query::select()
443            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
444            .from(OAuth2Sessions::Table)
445            .apply_filter(filter)
446            .build_sqlx(PostgresQueryBuilder);
447
448        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449            .traced()
450            .fetch_one(&mut *self.conn)
451            .await?;
452
453        count
454            .try_into()
455            .map_err(DatabaseError::to_invalid_operation)
456    }
457
458    #[tracing::instrument(
459        name = "db.oauth2_session.record_batch_activity",
460        skip_all,
461        fields(
462            db.query.text,
463        ),
464        err,
465    )]
466    async fn record_batch_activity(
467        &mut self,
468        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
469    ) -> Result<(), Self::Error> {
470        // Sort the activity by ID, so that when batching the updates, Postgres
471        // locks the rows in a stable order, preventing deadlocks
472        activities.sort_unstable();
473        let mut ids = Vec::with_capacity(activities.len());
474        let mut last_activities = Vec::with_capacity(activities.len());
475        let mut ips = Vec::with_capacity(activities.len());
476
477        for (id, last_activity, ip) in activities {
478            ids.push(Uuid::from(id));
479            last_activities.push(last_activity);
480            ips.push(ip);
481        }
482
483        let res = sqlx::query!(
484            r#"
485                UPDATE oauth2_sessions
486                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
487                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
488                FROM (
489                    SELECT *
490                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
491                        AS t(oauth2_session_id, last_active_at, last_active_ip)
492                ) AS t
493                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
494            "#,
495            &ids,
496            &last_activities,
497            &ips as &[Option<IpAddr>],
498        )
499        .traced()
500        .execute(&mut *self.conn)
501        .await?;
502
503        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
504
505        Ok(())
506    }
507
508    #[tracing::instrument(
509        name = "db.oauth2_session.record_user_agent",
510        skip_all,
511        fields(
512            db.query.text,
513            %session.id,
514            %session.scope,
515            client.id = %session.client_id,
516            session.user_agent = user_agent,
517        ),
518        err,
519    )]
520    async fn record_user_agent(
521        &mut self,
522        mut session: Session,
523        user_agent: String,
524    ) -> Result<Session, Self::Error> {
525        let res = sqlx::query!(
526            r#"
527                UPDATE oauth2_sessions
528                SET user_agent = $2
529                WHERE oauth2_session_id = $1
530            "#,
531            Uuid::from(session.id),
532            &*user_agent,
533        )
534        .traced()
535        .execute(&mut *self.conn)
536        .await?;
537
538        session.user_agent = Some(user_agent);
539
540        DatabaseError::ensure_affected_rows(&res, 1)?;
541
542        Ok(session)
543    }
544
545    #[tracing::instrument(
546        name = "repository.oauth2_session.set_human_name",
547        skip(self),
548        fields(
549            client.id = %session.client_id,
550            session.human_name = ?human_name,
551        ),
552        err,
553    )]
554    async fn set_human_name(
555        &mut self,
556        mut session: Session,
557        human_name: Option<String>,
558    ) -> Result<Session, Self::Error> {
559        let res = sqlx::query!(
560            r#"
561                UPDATE oauth2_sessions
562                SET human_name = $2
563                WHERE oauth2_session_id = $1
564            "#,
565            Uuid::from(session.id),
566            human_name.as_deref(),
567        )
568        .traced()
569        .execute(&mut *self.conn)
570        .await?;
571
572        session.human_name = human_name;
573
574        DatabaseError::ensure_affected_rows(&res, 1)?;
575
576        Ok(session)
577    }
578}