1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{
13 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14 },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError, DatabaseInconsistencyError,
27 filter::{Filter, StatementExt},
28 iden::UpstreamOAuthProviders,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
/// An implementation of [`UpstreamOAuthProviderRepository`] backed by a
/// PostgreSQL connection.
///
/// All queries are issued over the borrowed connection, so they participate
/// in whatever transaction (if any) the caller has open on it.
pub struct PgUpstreamOAuthProviderRepository<'c> {
    // Borrowed mutably for the lifetime of the repository; every method
    // reborrows it as `&mut *self.conn` to run its queries.
    conn: &'c mut PgConnection,
}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40 pub fn new(conn: &'c mut PgConnection) -> Self {
43 Self { conn }
44 }
45}
46
/// A raw row of the `upstream_oauth_providers` table, as read back by the
/// queries in this module.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, and
/// `#[enum_def]` generates the `ProviderLookupIden` enum that the
/// dynamically-built `list` query uses as column aliases, so field names must
/// match the SQL column names exactly. Most enum-like values are stored as
/// text and only parsed in the `TryFrom<ProviderLookup>` conversion below.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Space-separated scope string, parsed during conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // NULL while the provider is enabled.
    disabled_at: Option<DateTime<Utc>>,
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // Nullable JSON list of key/value pairs; NULL is treated as an empty list
    // during conversion.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
}
76
77impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
78 type Error = DatabaseInconsistencyError;
79
80 #[allow(clippy::too_many_lines)]
81 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
82 let id = value.upstream_oauth_provider_id.into();
83 let scope = value.scope.parse().map_err(|e| {
84 DatabaseInconsistencyError::on("upstream_oauth_providers")
85 .column("scope")
86 .row(id)
87 .source(e)
88 })?;
89 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
90 DatabaseInconsistencyError::on("upstream_oauth_providers")
91 .column("token_endpoint_auth_method")
92 .row(id)
93 .source(e)
94 })?;
95 let token_endpoint_signing_alg = value
96 .token_endpoint_signing_alg
97 .map(|x| x.parse())
98 .transpose()
99 .map_err(|e| {
100 DatabaseInconsistencyError::on("upstream_oauth_providers")
101 .column("token_endpoint_signing_alg")
102 .row(id)
103 .source(e)
104 })?;
105 let id_token_signed_response_alg =
106 value.id_token_signed_response_alg.parse().map_err(|e| {
107 DatabaseInconsistencyError::on("upstream_oauth_providers")
108 .column("id_token_signed_response_alg")
109 .row(id)
110 .source(e)
111 })?;
112
113 let userinfo_signed_response_alg = value
114 .userinfo_signed_response_alg
115 .map(|x| x.parse())
116 .transpose()
117 .map_err(|e| {
118 DatabaseInconsistencyError::on("upstream_oauth_providers")
119 .column("userinfo_signed_response_alg")
120 .row(id)
121 .source(e)
122 })?;
123
124 let authorization_endpoint_override = value
125 .authorization_endpoint_override
126 .map(|x| x.parse())
127 .transpose()
128 .map_err(|e| {
129 DatabaseInconsistencyError::on("upstream_oauth_providers")
130 .column("authorization_endpoint_override")
131 .row(id)
132 .source(e)
133 })?;
134
135 let token_endpoint_override = value
136 .token_endpoint_override
137 .map(|x| x.parse())
138 .transpose()
139 .map_err(|e| {
140 DatabaseInconsistencyError::on("upstream_oauth_providers")
141 .column("token_endpoint_override")
142 .row(id)
143 .source(e)
144 })?;
145
146 let userinfo_endpoint_override = value
147 .userinfo_endpoint_override
148 .map(|x| x.parse())
149 .transpose()
150 .map_err(|e| {
151 DatabaseInconsistencyError::on("upstream_oauth_providers")
152 .column("userinfo_endpoint_override")
153 .row(id)
154 .source(e)
155 })?;
156
157 let jwks_uri_override = value
158 .jwks_uri_override
159 .map(|x| x.parse())
160 .transpose()
161 .map_err(|e| {
162 DatabaseInconsistencyError::on("upstream_oauth_providers")
163 .column("jwks_uri_override")
164 .row(id)
165 .source(e)
166 })?;
167
168 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
169 DatabaseInconsistencyError::on("upstream_oauth_providers")
170 .column("discovery_mode")
171 .row(id)
172 .source(e)
173 })?;
174
175 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
176 DatabaseInconsistencyError::on("upstream_oauth_providers")
177 .column("pkce_mode")
178 .row(id)
179 .source(e)
180 })?;
181
182 let response_mode = value
183 .response_mode
184 .map(|x| x.parse())
185 .transpose()
186 .map_err(|e| {
187 DatabaseInconsistencyError::on("upstream_oauth_providers")
188 .column("response_mode")
189 .row(id)
190 .source(e)
191 })?;
192
193 let additional_authorization_parameters = value
194 .additional_parameters
195 .map(|Json(x)| x)
196 .unwrap_or_default();
197
198 let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
199 DatabaseInconsistencyError::on("upstream_oauth_providers")
200 .column("on_backchannel_logout")
201 .row(id)
202 .source(e)
203 })?;
204
205 Ok(UpstreamOAuthProvider {
206 id,
207 issuer: value.issuer,
208 human_name: value.human_name,
209 brand_name: value.brand_name,
210 scope,
211 client_id: value.client_id,
212 encrypted_client_secret: value.encrypted_client_secret,
213 token_endpoint_auth_method,
214 token_endpoint_signing_alg,
215 id_token_signed_response_alg,
216 fetch_userinfo: value.fetch_userinfo,
217 userinfo_signed_response_alg,
218 created_at: value.created_at,
219 disabled_at: value.disabled_at,
220 claims_imports: value.claims_imports.0,
221 authorization_endpoint_override,
222 token_endpoint_override,
223 userinfo_endpoint_override,
224 jwks_uri_override,
225 discovery_mode,
226 pkce_mode,
227 response_mode,
228 additional_authorization_parameters,
229 forward_login_hint: value.forward_login_hint,
230 on_backchannel_logout,
231 })
232 }
233}
234
235impl Filter for UpstreamOAuthProviderFilter<'_> {
236 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
237 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
238 Expr::col((
239 UpstreamOAuthProviders::Table,
240 UpstreamOAuthProviders::DisabledAt,
241 ))
242 .is_null()
243 .eq(enabled)
244 }))
245 }
246}
247
248#[async_trait]
249impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
250 type Error = DatabaseError;
251
252 #[tracing::instrument(
253 name = "db.upstream_oauth_provider.lookup",
254 skip_all,
255 fields(
256 db.query.text,
257 upstream_oauth_provider.id = %id,
258 ),
259 err,
260 )]
261 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
262 let res = sqlx::query_as!(
263 ProviderLookup,
264 r#"
265 SELECT
266 upstream_oauth_provider_id,
267 issuer,
268 human_name,
269 brand_name,
270 scope,
271 client_id,
272 encrypted_client_secret,
273 token_endpoint_signing_alg,
274 token_endpoint_auth_method,
275 id_token_signed_response_alg,
276 fetch_userinfo,
277 userinfo_signed_response_alg,
278 created_at,
279 disabled_at,
280 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
281 jwks_uri_override,
282 authorization_endpoint_override,
283 token_endpoint_override,
284 userinfo_endpoint_override,
285 discovery_mode,
286 pkce_mode,
287 response_mode,
288 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
289 forward_login_hint,
290 on_backchannel_logout
291 FROM upstream_oauth_providers
292 WHERE upstream_oauth_provider_id = $1
293 "#,
294 Uuid::from(id),
295 )
296 .traced()
297 .fetch_optional(&mut *self.conn)
298 .await?;
299
300 let res = res
301 .map(UpstreamOAuthProvider::try_from)
302 .transpose()
303 .map_err(DatabaseError::from)?;
304
305 Ok(res)
306 }
307
308 #[tracing::instrument(
309 name = "db.upstream_oauth_provider.add",
310 skip_all,
311 fields(
312 db.query.text,
313 upstream_oauth_provider.id,
314 upstream_oauth_provider.issuer = params.issuer,
315 upstream_oauth_provider.client_id = %params.client_id,
316 ),
317 err,
318 )]
319 async fn add(
320 &mut self,
321 rng: &mut (dyn RngCore + Send),
322 clock: &dyn Clock,
323 params: UpstreamOAuthProviderParams,
324 ) -> Result<UpstreamOAuthProvider, Self::Error> {
325 let created_at = clock.now();
326 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
327 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
328
329 sqlx::query!(
330 r#"
331 INSERT INTO upstream_oauth_providers (
332 upstream_oauth_provider_id,
333 issuer,
334 human_name,
335 brand_name,
336 scope,
337 token_endpoint_auth_method,
338 token_endpoint_signing_alg,
339 id_token_signed_response_alg,
340 fetch_userinfo,
341 userinfo_signed_response_alg,
342 client_id,
343 encrypted_client_secret,
344 claims_imports,
345 authorization_endpoint_override,
346 token_endpoint_override,
347 userinfo_endpoint_override,
348 jwks_uri_override,
349 discovery_mode,
350 pkce_mode,
351 response_mode,
352 forward_login_hint,
353 on_backchannel_logout,
354 created_at
355 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
356 $12, $13, $14, $15, $16, $17, $18, $19, $20,
357 $21, $22, $23)
358 "#,
359 Uuid::from(id),
360 params.issuer.as_deref(),
361 params.human_name.as_deref(),
362 params.brand_name.as_deref(),
363 params.scope.to_string(),
364 params.token_endpoint_auth_method.to_string(),
365 params
366 .token_endpoint_signing_alg
367 .as_ref()
368 .map(ToString::to_string),
369 params.id_token_signed_response_alg.to_string(),
370 params.fetch_userinfo,
371 params
372 .userinfo_signed_response_alg
373 .as_ref()
374 .map(ToString::to_string),
375 ¶ms.client_id,
376 params.encrypted_client_secret.as_deref(),
377 Json(¶ms.claims_imports) as _,
378 params
379 .authorization_endpoint_override
380 .as_ref()
381 .map(ToString::to_string),
382 params
383 .token_endpoint_override
384 .as_ref()
385 .map(ToString::to_string),
386 params
387 .userinfo_endpoint_override
388 .as_ref()
389 .map(ToString::to_string),
390 params.jwks_uri_override.as_ref().map(ToString::to_string),
391 params.discovery_mode.as_str(),
392 params.pkce_mode.as_str(),
393 params.response_mode.as_ref().map(ToString::to_string),
394 params.forward_login_hint,
395 params.on_backchannel_logout.as_str(),
396 created_at,
397 )
398 .traced()
399 .execute(&mut *self.conn)
400 .await?;
401
402 Ok(UpstreamOAuthProvider {
403 id,
404 issuer: params.issuer,
405 human_name: params.human_name,
406 brand_name: params.brand_name,
407 scope: params.scope,
408 client_id: params.client_id,
409 encrypted_client_secret: params.encrypted_client_secret,
410 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
411 token_endpoint_auth_method: params.token_endpoint_auth_method,
412 id_token_signed_response_alg: params.id_token_signed_response_alg,
413 fetch_userinfo: params.fetch_userinfo,
414 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
415 created_at,
416 disabled_at: None,
417 claims_imports: params.claims_imports,
418 authorization_endpoint_override: params.authorization_endpoint_override,
419 token_endpoint_override: params.token_endpoint_override,
420 userinfo_endpoint_override: params.userinfo_endpoint_override,
421 jwks_uri_override: params.jwks_uri_override,
422 discovery_mode: params.discovery_mode,
423 pkce_mode: params.pkce_mode,
424 response_mode: params.response_mode,
425 additional_authorization_parameters: params.additional_authorization_parameters,
426 on_backchannel_logout: params.on_backchannel_logout,
427 forward_login_hint: params.forward_login_hint,
428 })
429 }
430
431 #[tracing::instrument(
432 name = "db.upstream_oauth_provider.delete_by_id",
433 skip_all,
434 fields(
435 db.query.text,
436 upstream_oauth_provider.id = %id,
437 ),
438 err,
439 )]
440 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
441 {
444 let span = info_span!(
445 "db.oauth2_client.delete_by_id.authorization_sessions",
446 upstream_oauth_provider.id = %id,
447 { DB_QUERY_TEXT } = tracing::field::Empty,
448 );
449 sqlx::query!(
450 r#"
451 DELETE FROM upstream_oauth_authorization_sessions
452 WHERE upstream_oauth_provider_id = $1
453 "#,
454 Uuid::from(id),
455 )
456 .record(&span)
457 .execute(&mut *self.conn)
458 .instrument(span)
459 .await?;
460 }
461
462 {
465 let span = info_span!(
466 "db.oauth2_client.delete_by_id.links",
467 upstream_oauth_provider.id = %id,
468 { DB_QUERY_TEXT } = tracing::field::Empty,
469 );
470 sqlx::query!(
471 r#"
472 DELETE FROM upstream_oauth_links
473 WHERE upstream_oauth_provider_id = $1
474 "#,
475 Uuid::from(id),
476 )
477 .record(&span)
478 .execute(&mut *self.conn)
479 .instrument(span)
480 .await?;
481 }
482
483 let res = sqlx::query!(
484 r#"
485 DELETE FROM upstream_oauth_providers
486 WHERE upstream_oauth_provider_id = $1
487 "#,
488 Uuid::from(id),
489 )
490 .traced()
491 .execute(&mut *self.conn)
492 .await?;
493
494 DatabaseError::ensure_affected_rows(&res, 1)
495 }
496
497 #[tracing::instrument(
498 name = "db.upstream_oauth_provider.add",
499 skip_all,
500 fields(
501 db.query.text,
502 upstream_oauth_provider.id = %id,
503 upstream_oauth_provider.issuer = params.issuer,
504 upstream_oauth_provider.client_id = %params.client_id,
505 ),
506 err,
507 )]
508 async fn upsert(
509 &mut self,
510 clock: &dyn Clock,
511 id: Ulid,
512 params: UpstreamOAuthProviderParams,
513 ) -> Result<UpstreamOAuthProvider, Self::Error> {
514 let created_at = clock.now();
515
516 let created_at = sqlx::query_scalar!(
517 r#"
518 INSERT INTO upstream_oauth_providers (
519 upstream_oauth_provider_id,
520 issuer,
521 human_name,
522 brand_name,
523 scope,
524 token_endpoint_auth_method,
525 token_endpoint_signing_alg,
526 id_token_signed_response_alg,
527 fetch_userinfo,
528 userinfo_signed_response_alg,
529 client_id,
530 encrypted_client_secret,
531 claims_imports,
532 authorization_endpoint_override,
533 token_endpoint_override,
534 userinfo_endpoint_override,
535 jwks_uri_override,
536 discovery_mode,
537 pkce_mode,
538 response_mode,
539 additional_parameters,
540 forward_login_hint,
541 ui_order,
542 on_backchannel_logout,
543 created_at
544 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
545 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
546 $21, $22, $23, $24, $25)
547 ON CONFLICT (upstream_oauth_provider_id)
548 DO UPDATE
549 SET
550 issuer = EXCLUDED.issuer,
551 human_name = EXCLUDED.human_name,
552 brand_name = EXCLUDED.brand_name,
553 scope = EXCLUDED.scope,
554 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
555 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
556 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
557 fetch_userinfo = EXCLUDED.fetch_userinfo,
558 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
559 disabled_at = NULL,
560 client_id = EXCLUDED.client_id,
561 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
562 claims_imports = EXCLUDED.claims_imports,
563 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
564 token_endpoint_override = EXCLUDED.token_endpoint_override,
565 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
566 jwks_uri_override = EXCLUDED.jwks_uri_override,
567 discovery_mode = EXCLUDED.discovery_mode,
568 pkce_mode = EXCLUDED.pkce_mode,
569 response_mode = EXCLUDED.response_mode,
570 additional_parameters = EXCLUDED.additional_parameters,
571 forward_login_hint = EXCLUDED.forward_login_hint,
572 ui_order = EXCLUDED.ui_order,
573 on_backchannel_logout = EXCLUDED.on_backchannel_logout
574 RETURNING created_at
575 "#,
576 Uuid::from(id),
577 params.issuer.as_deref(),
578 params.human_name.as_deref(),
579 params.brand_name.as_deref(),
580 params.scope.to_string(),
581 params.token_endpoint_auth_method.to_string(),
582 params
583 .token_endpoint_signing_alg
584 .as_ref()
585 .map(ToString::to_string),
586 params.id_token_signed_response_alg.to_string(),
587 params.fetch_userinfo,
588 params
589 .userinfo_signed_response_alg
590 .as_ref()
591 .map(ToString::to_string),
592 ¶ms.client_id,
593 params.encrypted_client_secret.as_deref(),
594 Json(¶ms.claims_imports) as _,
595 params
596 .authorization_endpoint_override
597 .as_ref()
598 .map(ToString::to_string),
599 params
600 .token_endpoint_override
601 .as_ref()
602 .map(ToString::to_string),
603 params
604 .userinfo_endpoint_override
605 .as_ref()
606 .map(ToString::to_string),
607 params.jwks_uri_override.as_ref().map(ToString::to_string),
608 params.discovery_mode.as_str(),
609 params.pkce_mode.as_str(),
610 params.response_mode.as_ref().map(ToString::to_string),
611 Json(¶ms.additional_authorization_parameters) as _,
612 params.forward_login_hint,
613 params.ui_order,
614 params.on_backchannel_logout.as_str(),
615 created_at,
616 )
617 .traced()
618 .fetch_one(&mut *self.conn)
619 .await?;
620
621 Ok(UpstreamOAuthProvider {
622 id,
623 issuer: params.issuer,
624 human_name: params.human_name,
625 brand_name: params.brand_name,
626 scope: params.scope,
627 client_id: params.client_id,
628 encrypted_client_secret: params.encrypted_client_secret,
629 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
630 token_endpoint_auth_method: params.token_endpoint_auth_method,
631 id_token_signed_response_alg: params.id_token_signed_response_alg,
632 fetch_userinfo: params.fetch_userinfo,
633 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
634 created_at,
635 disabled_at: None,
636 claims_imports: params.claims_imports,
637 authorization_endpoint_override: params.authorization_endpoint_override,
638 token_endpoint_override: params.token_endpoint_override,
639 userinfo_endpoint_override: params.userinfo_endpoint_override,
640 jwks_uri_override: params.jwks_uri_override,
641 discovery_mode: params.discovery_mode,
642 pkce_mode: params.pkce_mode,
643 response_mode: params.response_mode,
644 additional_authorization_parameters: params.additional_authorization_parameters,
645 forward_login_hint: params.forward_login_hint,
646 on_backchannel_logout: params.on_backchannel_logout,
647 })
648 }
649
650 #[tracing::instrument(
651 name = "db.upstream_oauth_provider.disable",
652 skip_all,
653 fields(
654 db.query.text,
655 %upstream_oauth_provider.id,
656 ),
657 err,
658 )]
659 async fn disable(
660 &mut self,
661 clock: &dyn Clock,
662 mut upstream_oauth_provider: UpstreamOAuthProvider,
663 ) -> Result<UpstreamOAuthProvider, Self::Error> {
664 let disabled_at = clock.now();
665 let res = sqlx::query!(
666 r#"
667 UPDATE upstream_oauth_providers
668 SET disabled_at = $2
669 WHERE upstream_oauth_provider_id = $1
670 "#,
671 Uuid::from(upstream_oauth_provider.id),
672 disabled_at,
673 )
674 .traced()
675 .execute(&mut *self.conn)
676 .await?;
677
678 DatabaseError::ensure_affected_rows(&res, 1)?;
679
680 upstream_oauth_provider.disabled_at = Some(disabled_at);
681
682 Ok(upstream_oauth_provider)
683 }
684
685 #[tracing::instrument(
686 name = "db.upstream_oauth_provider.list",
687 skip_all,
688 fields(
689 db.query.text,
690 ),
691 err,
692 )]
693 async fn list(
694 &mut self,
695 filter: UpstreamOAuthProviderFilter<'_>,
696 pagination: Pagination,
697 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
698 let (sql, arguments) = Query::select()
699 .expr_as(
700 Expr::col((
701 UpstreamOAuthProviders::Table,
702 UpstreamOAuthProviders::UpstreamOAuthProviderId,
703 )),
704 ProviderLookupIden::UpstreamOauthProviderId,
705 )
706 .expr_as(
707 Expr::col((
708 UpstreamOAuthProviders::Table,
709 UpstreamOAuthProviders::Issuer,
710 )),
711 ProviderLookupIden::Issuer,
712 )
713 .expr_as(
714 Expr::col((
715 UpstreamOAuthProviders::Table,
716 UpstreamOAuthProviders::HumanName,
717 )),
718 ProviderLookupIden::HumanName,
719 )
720 .expr_as(
721 Expr::col((
722 UpstreamOAuthProviders::Table,
723 UpstreamOAuthProviders::BrandName,
724 )),
725 ProviderLookupIden::BrandName,
726 )
727 .expr_as(
728 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
729 ProviderLookupIden::Scope,
730 )
731 .expr_as(
732 Expr::col((
733 UpstreamOAuthProviders::Table,
734 UpstreamOAuthProviders::ClientId,
735 )),
736 ProviderLookupIden::ClientId,
737 )
738 .expr_as(
739 Expr::col((
740 UpstreamOAuthProviders::Table,
741 UpstreamOAuthProviders::EncryptedClientSecret,
742 )),
743 ProviderLookupIden::EncryptedClientSecret,
744 )
745 .expr_as(
746 Expr::col((
747 UpstreamOAuthProviders::Table,
748 UpstreamOAuthProviders::TokenEndpointSigningAlg,
749 )),
750 ProviderLookupIden::TokenEndpointSigningAlg,
751 )
752 .expr_as(
753 Expr::col((
754 UpstreamOAuthProviders::Table,
755 UpstreamOAuthProviders::TokenEndpointAuthMethod,
756 )),
757 ProviderLookupIden::TokenEndpointAuthMethod,
758 )
759 .expr_as(
760 Expr::col((
761 UpstreamOAuthProviders::Table,
762 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
763 )),
764 ProviderLookupIden::IdTokenSignedResponseAlg,
765 )
766 .expr_as(
767 Expr::col((
768 UpstreamOAuthProviders::Table,
769 UpstreamOAuthProviders::FetchUserinfo,
770 )),
771 ProviderLookupIden::FetchUserinfo,
772 )
773 .expr_as(
774 Expr::col((
775 UpstreamOAuthProviders::Table,
776 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
777 )),
778 ProviderLookupIden::UserinfoSignedResponseAlg,
779 )
780 .expr_as(
781 Expr::col((
782 UpstreamOAuthProviders::Table,
783 UpstreamOAuthProviders::CreatedAt,
784 )),
785 ProviderLookupIden::CreatedAt,
786 )
787 .expr_as(
788 Expr::col((
789 UpstreamOAuthProviders::Table,
790 UpstreamOAuthProviders::DisabledAt,
791 )),
792 ProviderLookupIden::DisabledAt,
793 )
794 .expr_as(
795 Expr::col((
796 UpstreamOAuthProviders::Table,
797 UpstreamOAuthProviders::ClaimsImports,
798 )),
799 ProviderLookupIden::ClaimsImports,
800 )
801 .expr_as(
802 Expr::col((
803 UpstreamOAuthProviders::Table,
804 UpstreamOAuthProviders::JwksUriOverride,
805 )),
806 ProviderLookupIden::JwksUriOverride,
807 )
808 .expr_as(
809 Expr::col((
810 UpstreamOAuthProviders::Table,
811 UpstreamOAuthProviders::TokenEndpointOverride,
812 )),
813 ProviderLookupIden::TokenEndpointOverride,
814 )
815 .expr_as(
816 Expr::col((
817 UpstreamOAuthProviders::Table,
818 UpstreamOAuthProviders::AuthorizationEndpointOverride,
819 )),
820 ProviderLookupIden::AuthorizationEndpointOverride,
821 )
822 .expr_as(
823 Expr::col((
824 UpstreamOAuthProviders::Table,
825 UpstreamOAuthProviders::UserinfoEndpointOverride,
826 )),
827 ProviderLookupIden::UserinfoEndpointOverride,
828 )
829 .expr_as(
830 Expr::col((
831 UpstreamOAuthProviders::Table,
832 UpstreamOAuthProviders::DiscoveryMode,
833 )),
834 ProviderLookupIden::DiscoveryMode,
835 )
836 .expr_as(
837 Expr::col((
838 UpstreamOAuthProviders::Table,
839 UpstreamOAuthProviders::PkceMode,
840 )),
841 ProviderLookupIden::PkceMode,
842 )
843 .expr_as(
844 Expr::col((
845 UpstreamOAuthProviders::Table,
846 UpstreamOAuthProviders::ResponseMode,
847 )),
848 ProviderLookupIden::ResponseMode,
849 )
850 .expr_as(
851 Expr::col((
852 UpstreamOAuthProviders::Table,
853 UpstreamOAuthProviders::AdditionalParameters,
854 )),
855 ProviderLookupIden::AdditionalParameters,
856 )
857 .expr_as(
858 Expr::col((
859 UpstreamOAuthProviders::Table,
860 UpstreamOAuthProviders::ForwardLoginHint,
861 )),
862 ProviderLookupIden::ForwardLoginHint,
863 )
864 .expr_as(
865 Expr::col((
866 UpstreamOAuthProviders::Table,
867 UpstreamOAuthProviders::OnBackchannelLogout,
868 )),
869 ProviderLookupIden::OnBackchannelLogout,
870 )
871 .from(UpstreamOAuthProviders::Table)
872 .apply_filter(filter)
873 .generate_pagination(
874 (
875 UpstreamOAuthProviders::Table,
876 UpstreamOAuthProviders::UpstreamOAuthProviderId,
877 ),
878 pagination,
879 )
880 .build_sqlx(PostgresQueryBuilder);
881
882 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
883 .traced()
884 .fetch_all(&mut *self.conn)
885 .await?;
886
887 let page = pagination
888 .process(edges)
889 .try_map(UpstreamOAuthProvider::try_from)?;
890
891 return Ok(page);
892 }
893
894 #[tracing::instrument(
895 name = "db.upstream_oauth_provider.count",
896 skip_all,
897 fields(
898 db.query.text,
899 ),
900 err,
901 )]
902 async fn count(
903 &mut self,
904 filter: UpstreamOAuthProviderFilter<'_>,
905 ) -> Result<usize, Self::Error> {
906 let (sql, arguments) = Query::select()
907 .expr(
908 Expr::col((
909 UpstreamOAuthProviders::Table,
910 UpstreamOAuthProviders::UpstreamOAuthProviderId,
911 ))
912 .count(),
913 )
914 .from(UpstreamOAuthProviders::Table)
915 .apply_filter(filter)
916 .build_sqlx(PostgresQueryBuilder);
917
918 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
919 .traced()
920 .fetch_one(&mut *self.conn)
921 .await?;
922
923 count
924 .try_into()
925 .map_err(DatabaseError::to_invalid_operation)
926 }
927
928 #[tracing::instrument(
929 name = "db.upstream_oauth_provider.all_enabled",
930 skip_all,
931 fields(
932 db.query.text,
933 ),
934 err,
935 )]
936 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
937 let res = sqlx::query_as!(
938 ProviderLookup,
939 r#"
940 SELECT
941 upstream_oauth_provider_id,
942 issuer,
943 human_name,
944 brand_name,
945 scope,
946 client_id,
947 encrypted_client_secret,
948 token_endpoint_signing_alg,
949 token_endpoint_auth_method,
950 id_token_signed_response_alg,
951 fetch_userinfo,
952 userinfo_signed_response_alg,
953 created_at,
954 disabled_at,
955 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
956 jwks_uri_override,
957 authorization_endpoint_override,
958 token_endpoint_override,
959 userinfo_endpoint_override,
960 discovery_mode,
961 pkce_mode,
962 response_mode,
963 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
964 forward_login_hint,
965 on_backchannel_logout
966 FROM upstream_oauth_providers
967 WHERE disabled_at IS NULL
968 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
969 "#,
970 )
971 .traced()
972 .fetch_all(&mut *self.conn)
973 .await?;
974
975 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
976 Ok(res?)
977 }
978}