1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the parameters for a test upstream provider.
    ///
    /// Every test in this module uses the same settings except for the
    /// issuer, the client ID and the scope. Keeping the full struct literal
    /// in one place means a new field on [`UpstreamOAuthProviderParams`]
    /// only has to be added here.
    ///
    /// * `issuer` - the optional issuer URL, copied into an owned `String`
    /// * `client_id` - the client ID registered with the provider
    /// * `scope` - the scope requested from the provider
    fn provider_params(
        issuer: Option<&str>,
        client_id: impl Into<String>,
        scope: Scope,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.into(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        }
    }

    /// Walk through the whole lifecycle of providers, authorization sessions
    /// and links. The steps are: create, look up, complete a session with a
    /// link, consume the session, associate the link to a user, list and
    /// count with filters, disable the provider, and finally delete it.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Right after the migrations, there should be no enabled provider
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Create a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(
                    Some("https://example.com/"),
                    "client-id",
                    Scope::from_iter([OPENID]),
                ),
            )
            .await
            .unwrap();

        // It should be possible to look it up by its ID
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be the only enabled provider
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start an authorization session against that provider
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // A freshly created session is pending and not linked yet
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link for a subject on that provider
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // The link can be found both by ID...
        repo.upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");

        // ...and by (provider, subject)
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Completing the session attaches the link to it
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // Consuming the session is persisted as well
        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        // Associate the link to a newly created user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // Listing links with every filter set should return exactly that link
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Provider counts before disabling: one in total, one enabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // Provider counts after disabling: still one in total, now disabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // The session is still reachable through a provider filter
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Deleting the provider removes it from the count
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Check forward/backward cursor pagination and the enabled/disabled
    /// filters on the provider repository, using 20 providers created ten
    /// mock-seconds apart.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Nothing exists yet
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        for idx in 0..20 {
            let client_id = format!("client-{idx}");
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, client_id, scope.clone()))
                .await
                .unwrap();
            ids.push(provider.id);
            // Advance the mock clock (10 seconds) between providers; the test
            // then expects list order to match creation order
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // First page, forward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // All providers are enabled, so the enabled-only filter is a no-op
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Second page, forward from the last ID of the first page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page, backward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Backward from the first ID of the last page
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors are exclusive: after 5 and before 8 yields 6 and 7
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // No provider was disabled
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Check cursor pagination on the session repository, and filtering of
    /// sessions by the `sub` and `sid` claims of their ID token.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let scope = Scope::from_iter([OPENID]);

        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id", scope),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // No session exists yet
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        // Session `idx` uses link `idx % 3` and sid `idx % 2`, so each
        // (subject, sid) pair repeats every 6 sessions
        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            // Advance the mock clock (10 seconds) between sessions; the test
            // then expects list order to match creation order
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // First page, forward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Second page, forward from the last ID of the first page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page, backward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Backward from the first ID of the last page
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors are exclusive: after 5 and before 11 yields 6..=10
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // (alice, one) matches idx 0, 6, 12, 18; (bob, two) matches 1, 7, 13, 19
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        // Listing with the claim filters only returns matching sessions
        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            assert_eq!(
                edge.id_token_claims().unwrap().get("sub").unwrap().as_str(),
                Some("alice")
            );
            assert_eq!(
                edge.id_token_claims().unwrap().get("sid").unwrap().as_str(),
                Some("one")
            );
        }
    }
}