mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
12 UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
13 UpstreamOAuthProviderResponseMode, UpstreamOAuthProviderTokenAuthMethod,
14};
15use mas_iana::jose::JsonWebSignatureAlg;
16use oauth2_types::scope::Scope;
17use rand_core::RngCore;
18use ulid::Ulid;
19use url::Url;
20
21use crate::{Clock, Pagination, pagination::Page, repository_impl};
22
23/// Structure which holds parameters when inserting or updating an upstream
24/// OAuth 2.0 provider
25pub struct UpstreamOAuthProviderParams {
26 /// The OIDC issuer of the provider
27 pub issuer: Option<String>,
28
29 /// A human-readable name for the provider
30 pub human_name: Option<String>,
31
32 /// A brand identifier, e.g. "apple" or "google"
33 pub brand_name: Option<String>,
34
35 /// The scope to request during the authorization flow
36 pub scope: Scope,
37
38 /// The token endpoint authentication method
39 pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
40
41 /// The JWT signing algorithm to use when then `client_secret_jwt` or
42 /// `private_key_jwt` authentication methods are used
43 pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
44
45 /// Expected signature for the JWT payload returned by the token
46 /// authentication endpoint.
47 ///
48 /// Defaults to `RS256`.
49 pub id_token_signed_response_alg: JsonWebSignatureAlg,
50
51 /// Whether to fetch the user profile from the userinfo endpoint,
52 /// or to rely on the data returned in the `id_token` from the
53 /// `token_endpoint`.
54 pub fetch_userinfo: bool,
55
56 /// Expected signature for the JWT payload returned by the userinfo
57 /// endpoint.
58 ///
59 /// If not specified, the response is expected to be an unsigned JSON
60 /// payload. Defaults to `None`.
61 pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
62
63 /// The client ID to use when authenticating to the upstream
64 pub client_id: String,
65
66 /// The encrypted client secret to use when authenticating to the upstream
67 pub encrypted_client_secret: Option<String>,
68
69 /// How claims should be imported from the upstream provider
70 pub claims_imports: UpstreamOAuthProviderClaimsImports,
71
72 /// The URL to use as the authorization endpoint. If `None`, the URL will be
73 /// discovered
74 pub authorization_endpoint_override: Option<Url>,
75
76 /// The URL to use as the token endpoint. If `None`, the URL will be
77 /// discovered
78 pub token_endpoint_override: Option<Url>,
79
80 /// The URL to use as the userinfo endpoint. If `None`, the URL will be
81 /// discovered
82 pub userinfo_endpoint_override: Option<Url>,
83
84 /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
85 pub jwks_uri_override: Option<Url>,
86
87 /// How the provider metadata should be discovered
88 pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
89
90 /// How should PKCE be used
91 pub pkce_mode: UpstreamOAuthProviderPkceMode,
92
93 /// What response mode it should ask
94 pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
95
96 /// Additional parameters to include in the authorization request
97 pub additional_authorization_parameters: Vec<(String, String)>,
98
99 /// Whether to forward the login hint to the upstream provider.
100 pub forward_login_hint: bool,
101
102 /// The position of the provider in the UI
103 pub ui_order: i32,
104
105 /// The behavior when receiving a backchannel logout notification
106 pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
107}
108
109/// Filter parameters for listing upstream OAuth 2.0 providers
110#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
111pub struct UpstreamOAuthProviderFilter<'a> {
112 /// Filter by whether the provider is enabled
113 ///
114 /// If `None`, all providers are returned
115 enabled: Option<bool>,
116
117 _lifetime: PhantomData<&'a ()>,
118}
119
120impl UpstreamOAuthProviderFilter<'_> {
121 /// Create a new [`UpstreamOAuthProviderFilter`] with default values
122 #[must_use]
123 pub fn new() -> Self {
124 Self::default()
125 }
126
127 /// Return only enabled providers
128 #[must_use]
129 pub const fn enabled_only(mut self) -> Self {
130 self.enabled = Some(true);
131 self
132 }
133
134 /// Return only disabled providers
135 #[must_use]
136 pub const fn disabled_only(mut self) -> Self {
137 self.enabled = Some(false);
138 self
139 }
140
141 /// Get the enabled filter
142 ///
143 /// Returns `None` if the filter is not set
144 #[must_use]
145 pub const fn enabled(&self) -> Option<bool> {
146 self.enabled
147 }
148}
149
150/// An [`UpstreamOAuthProviderRepository`] helps interacting with
151/// [`UpstreamOAuthProvider`] saved in the storage backend
152#[async_trait]
153pub trait UpstreamOAuthProviderRepository: Send + Sync {
154 /// The error type returned by the repository
155 type Error;
156
157 /// Lookup an upstream OAuth provider by its ID
158 ///
159 /// Returns `None` if the provider was not found
160 ///
161 /// # Parameters
162 ///
163 /// * `id`: The ID of the provider to lookup
164 ///
165 /// # Errors
166 ///
167 /// Returns [`Self::Error`] if the underlying repository fails
168 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
169
170 /// Add a new upstream OAuth provider
171 ///
172 /// Returns the newly created provider
173 ///
174 /// # Parameters
175 ///
176 /// * `rng`: A random number generator
177 /// * `clock`: The clock used to generate timestamps
178 /// * `params`: The parameters of the provider to add
179 ///
180 /// # Errors
181 ///
182 /// Returns [`Self::Error`] if the underlying repository fails
183 async fn add(
184 &mut self,
185 rng: &mut (dyn RngCore + Send),
186 clock: &dyn Clock,
187 params: UpstreamOAuthProviderParams,
188 ) -> Result<UpstreamOAuthProvider, Self::Error>;
189
190 /// Delete an upstream OAuth provider
191 ///
192 /// # Parameters
193 ///
194 /// * `provider`: The provider to delete
195 ///
196 /// # Errors
197 ///
198 /// Returns [`Self::Error`] if the underlying repository fails
199 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
200 self.delete_by_id(provider.id).await
201 }
202
203 /// Delete an upstream OAuth provider by its ID
204 ///
205 /// # Parameters
206 ///
207 /// * `id`: The ID of the provider to delete
208 ///
209 /// # Errors
210 ///
211 /// Returns [`Self::Error`] if the underlying repository fails
212 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
213
214 /// Insert or update an upstream OAuth provider
215 ///
216 /// # Parameters
217 ///
218 /// * `clock`: The clock used to generate timestamps
219 /// * `id`: The ID of the provider to update
220 /// * `params`: The parameters of the provider to update
221 ///
222 /// # Errors
223 ///
224 /// Returns [`Self::Error`] if the underlying repository fails
225 async fn upsert(
226 &mut self,
227 clock: &dyn Clock,
228 id: Ulid,
229 params: UpstreamOAuthProviderParams,
230 ) -> Result<UpstreamOAuthProvider, Self::Error>;
231
232 /// Disable an upstream OAuth provider
233 ///
234 /// Returns the disabled provider
235 ///
236 /// # Parameters
237 ///
238 /// * `clock`: The clock used to generate timestamps
239 /// * `provider`: The provider to disable
240 ///
241 /// # Errors
242 ///
243 /// Returns [`Self::Error`] if the underlying repository fails
244 async fn disable(
245 &mut self,
246 clock: &dyn Clock,
247 provider: UpstreamOAuthProvider,
248 ) -> Result<UpstreamOAuthProvider, Self::Error>;
249
250 /// List [`UpstreamOAuthProvider`] with the given filter and pagination
251 ///
252 /// # Parameters
253 ///
254 /// * `filter`: The filter to apply
255 /// * `pagination`: The pagination parameters
256 ///
257 /// # Errors
258 ///
259 /// Returns [`Self::Error`] if the underlying repository fails
260 async fn list(
261 &mut self,
262 filter: UpstreamOAuthProviderFilter<'_>,
263 pagination: Pagination,
264 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
265
266 /// Count the number of [`UpstreamOAuthProvider`] with the given filter
267 ///
268 /// # Parameters
269 ///
270 /// * `filter`: The filter to apply
271 ///
272 /// # Errors
273 ///
274 /// Returns [`Self::Error`] if the underlying repository fails
275 async fn count(
276 &mut self,
277 filter: UpstreamOAuthProviderFilter<'_>,
278 ) -> Result<usize, Self::Error>;
279
280 /// Get all enabled upstream OAuth providers
281 ///
282 /// # Errors
283 ///
284 /// Returns [`Self::Error`] if the underlying repository fails
285 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
286}
287
288repository_impl!(UpstreamOAuthProviderRepository:
289 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
290
291 async fn add(
292 &mut self,
293 rng: &mut (dyn RngCore + Send),
294 clock: &dyn Clock,
295 params: UpstreamOAuthProviderParams
296 ) -> Result<UpstreamOAuthProvider, Self::Error>;
297
298 async fn upsert(
299 &mut self,
300 clock: &dyn Clock,
301 id: Ulid,
302 params: UpstreamOAuthProviderParams
303 ) -> Result<UpstreamOAuthProvider, Self::Error>;
304
305 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;
306
307 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
308
309 async fn disable(
310 &mut self,
311 clock: &dyn Clock,
312 provider: UpstreamOAuthProvider
313 ) -> Result<UpstreamOAuthProvider, Self::Error>;
314
315 async fn list(
316 &mut self,
317 filter: UpstreamOAuthProviderFilter<'_>,
318 pagination: Pagination
319 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
320
321 async fn count(
322 &mut self,
323 filter: UpstreamOAuthProviderFilter<'_>
324 ) -> Result<usize, Self::Error>;
325
326 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
327);