1use std::collections::HashMap;
8
9use axum::{
10 BoxError, Json,
11 extract::{
12 Form, FromRequest,
13 rejection::{FailedToDeserializeForm, FormRejection},
14 },
15 response::IntoResponse,
16};
17use headers::authorization::{Basic, Bearer, Credentials as _};
18use http::{Request, StatusCode};
19use mas_data_model::{Client, JwksOrJwksUri};
20use mas_http::RequestBuilderExt;
21use mas_iana::oauth::OAuthClientAuthenticationMethod;
22use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
23use mas_keystore::Encrypter;
24use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
25use oauth2_types::errors::{ClientError, ClientErrorCode};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::Value;
28use thiserror::Error;
29
30use crate::record_error;
31
/// Value of the `client_assertion_type` form field that selects JWT bearer
/// client assertions.
///
/// `ClientAuthorization::from_request` only parses `client_assertion` as a JWT
/// when the submitted type equals this string.
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
33
/// Form body as deserialized by the extractor: the client authentication
/// fields, plus every other field collected into `F`.
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    /// `client_id` form field, if present.
    client_id: Option<String>,
    /// `client_secret` form field, if present.
    client_secret: Option<String>,
    /// `client_assertion_type` form field, if present.
    client_assertion_type: Option<String>,
    /// `client_assertion` form field, if present.
    client_assertion: Option<String>,

    /// The remaining form fields, gathered through `#[serde(flatten)]`.
    #[serde(flatten)]
    inner: F,
}
44
/// Client credentials as presented on a request, before any verification.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` was sent in the form, with no secret or assertion.
    None {
        client_id: String,
    },
    /// Client ID and secret taken from a `Basic` `Authorization` header.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// Client ID and secret taken from the form body.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A JWT sent as `client_assertion`; `client_id` is the `sub` claim of
    /// that JWT. The signature has not been checked at this point.
    ClientAssertionJwtBearer {
        client_id: String,
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
    /// Token taken from a `Bearer` `Authorization` header.
    BearerToken {
        token: String,
    },
}
66
67impl Credentials {
68 #[must_use]
70 pub fn client_id(&self) -> Option<&str> {
71 match self {
72 Credentials::None { client_id }
73 | Credentials::ClientSecretBasic { client_id, .. }
74 | Credentials::ClientSecretPost { client_id, .. }
75 | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
76 Credentials::BearerToken { .. } => None,
77 }
78 }
79
80 #[must_use]
82 pub fn bearer_token(&self) -> Option<&str> {
83 match self {
84 Credentials::BearerToken { token } => Some(token),
85 _ => None,
86 }
87 }
88
89 pub async fn fetch<E>(
96 &self,
97 repo: &mut impl RepositoryAccess<Error = E>,
98 ) -> Result<Option<Client>, E> {
99 let client_id = match self {
100 Credentials::None { client_id }
101 | Credentials::ClientSecretBasic { client_id, .. }
102 | Credentials::ClientSecretPost { client_id, .. }
103 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
104 Credentials::BearerToken { .. } => return Ok(None),
105 };
106
107 repo.oauth2_client().find_by_client_id(client_id).await
108 }
109
110 #[tracing::instrument(skip_all)]
116 pub async fn verify(
117 &self,
118 http_client: &reqwest::Client,
119 encrypter: &Encrypter,
120 method: &OAuthClientAuthenticationMethod,
121 client: &Client,
122 ) -> Result<(), CredentialsVerificationError> {
123 match (self, method) {
124 (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
125
126 (
127 Credentials::ClientSecretPost { client_secret, .. },
128 OAuthClientAuthenticationMethod::ClientSecretPost,
129 )
130 | (
131 Credentials::ClientSecretBasic { client_secret, .. },
132 OAuthClientAuthenticationMethod::ClientSecretBasic,
133 ) => {
134 let encrypted_client_secret = client
136 .encrypted_client_secret
137 .as_ref()
138 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
139
140 let decrypted_client_secret = encrypter
141 .decrypt_string(encrypted_client_secret)
142 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
143
144 if client_secret.as_bytes() != decrypted_client_secret {
146 return Err(CredentialsVerificationError::ClientSecretMismatch);
147 }
148 }
149
150 (
151 Credentials::ClientAssertionJwtBearer { jwt, .. },
152 OAuthClientAuthenticationMethod::PrivateKeyJwt,
153 ) => {
154 let jwks = client
156 .jwks
157 .as_ref()
158 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
159
160 let jwks = fetch_jwks(http_client, jwks)
161 .await
162 .map_err(CredentialsVerificationError::JwksFetchFailed)?;
163
164 jwt.verify_with_jwks(&jwks)
165 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
166 }
167
168 (
169 Credentials::ClientAssertionJwtBearer { jwt, .. },
170 OAuthClientAuthenticationMethod::ClientSecretJwt,
171 ) => {
172 let encrypted_client_secret = client
174 .encrypted_client_secret
175 .as_ref()
176 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
177
178 let decrypted_client_secret = encrypter
179 .decrypt_string(encrypted_client_secret)
180 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
181
182 jwt.verify_with_shared_secret(decrypted_client_secret)
183 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
184 }
185
186 (_, _) => {
187 return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
188 }
189 }
190 Ok(())
191 }
192}
193
194async fn fetch_jwks(
195 http_client: &reqwest::Client,
196 jwks: &JwksOrJwksUri,
197) -> Result<PublicJsonWebKeySet, BoxError> {
198 let uri = match jwks {
199 JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
200 JwksOrJwksUri::JwksUri(u) => u,
201 };
202
203 let response = http_client
204 .get(uri.as_str())
205 .send_traced()
206 .await?
207 .error_for_status()?
208 .json()
209 .await?;
210
211 Ok(response)
212}
213
/// Reasons why [`Credentials::verify`] can reject a set of credentials.
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
    /// The client secret stored for the client could not be decrypted.
    #[error("failed to decrypt client credentials")]
    DecryptionError,

    /// The client lacks the secret or JWKS required by the authentication
    /// method.
    #[error("invalid client configuration")]
    InvalidClientConfig,

    /// The presented client secret differs from the stored one.
    #[error("client secret did not match")]
    ClientSecretMismatch,

    /// The kind of credentials presented is not valid for the authentication
    /// method being checked.
    #[error("authentication method mismatch")]
    AuthenticationMethodMismatch,

    /// The client assertion JWT failed signature verification.
    #[error("invalid assertion signature")]
    InvalidAssertionSignature,

    /// The client's JWKS could not be retrieved.
    #[error("failed to fetch jwks")]
    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
234
235impl CredentialsVerificationError {
236 #[must_use]
238 pub fn is_internal(&self) -> bool {
239 matches!(
240 self,
241 Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
242 )
243 }
244}
245
/// Result of extracting client credentials from a request, together with the
/// rest of the form body.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    /// The credentials found on the request; not yet verified.
    pub credentials: Credentials,
    /// The remaining form fields, or `None` when the request did not have a
    /// form content type.
    pub form: Option<F>,
}
251
252impl<F> ClientAuthorization<F> {
253 #[must_use]
255 pub fn client_id(&self) -> Option<&str> {
256 self.credentials.client_id()
257 }
258}
259
/// Rejections produced while extracting a [`ClientAuthorization`] from a
/// request.
#[derive(Debug, Error)]
pub enum ClientAuthorizationError {
    /// The `Authorization` header is present but is neither a decodable
    /// `Basic` nor a decodable `Bearer` header.
    #[error("Invalid Authorization header")]
    InvalidHeader,

    /// The form body could not be deserialized.
    #[error("Could not deserialize request body")]
    BadForm(#[source] FailedToDeserializeForm),

    /// The `client_id` form field disagrees with the client ID carried by the
    /// credentials (Basic header user name, or `sub` claim of the assertion).
    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
    ClientIdMismatch { credential: String, form: String },

    /// A client assertion was sent with a type other than JWT bearer.
    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
    UnsupportedClientAssertion { client_assertion_type: String },

    /// Neither an `Authorization` header nor any credential form field was
    /// sent.
    #[error("No credentials were presented")]
    MissingCredentials,

    /// The combination of header and form fields matches no supported
    /// authentication scheme.
    #[error("Invalid request")]
    InvalidRequest,

    /// The client assertion is not a parsable JWT, or has no string `sub`
    /// claim.
    #[error("Invalid client_assertion")]
    InvalidAssertion,

    /// Any other failure while reading the form body.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error>),
}
286
287impl IntoResponse for ClientAuthorizationError {
288 fn into_response(self) -> axum::response::Response {
289 let sentry_event_id = record_error!(self, Self::Internal(_));
290 match &self {
291 ClientAuthorizationError::InvalidHeader => (
292 StatusCode::BAD_REQUEST,
293 sentry_event_id,
294 Json(ClientError::new(
295 ClientErrorCode::InvalidRequest,
296 "Invalid Authorization header",
297 )),
298 ),
299
300 ClientAuthorizationError::BadForm(err) => (
301 StatusCode::BAD_REQUEST,
302 sentry_event_id,
303 Json(
304 ClientError::from(ClientErrorCode::InvalidRequest)
305 .with_description(format!("{err}")),
306 ),
307 ),
308
309 ClientAuthorizationError::ClientIdMismatch { .. } => (
310 StatusCode::BAD_REQUEST,
311 sentry_event_id,
312 Json(
313 ClientError::from(ClientErrorCode::InvalidGrant)
314 .with_description(format!("{self}")),
315 ),
316 ),
317
318 ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
319 StatusCode::BAD_REQUEST,
320 sentry_event_id,
321 Json(
322 ClientError::from(ClientErrorCode::InvalidRequest)
323 .with_description(format!("{self}")),
324 ),
325 ),
326
327 ClientAuthorizationError::MissingCredentials => (
328 StatusCode::BAD_REQUEST,
329 sentry_event_id,
330 Json(ClientError::new(
331 ClientErrorCode::InvalidRequest,
332 "No credentials were presented",
333 )),
334 ),
335
336 ClientAuthorizationError::InvalidRequest => (
337 StatusCode::BAD_REQUEST,
338 sentry_event_id,
339 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
340 ),
341
342 ClientAuthorizationError::InvalidAssertion => (
343 StatusCode::BAD_REQUEST,
344 sentry_event_id,
345 Json(ClientError::new(
346 ClientErrorCode::InvalidRequest,
347 "Invalid client_assertion",
348 )),
349 ),
350
351 ClientAuthorizationError::Internal(e) => (
352 StatusCode::INTERNAL_SERVER_ERROR,
353 sentry_event_id,
354 Json(
355 ClientError::from(ClientErrorCode::ServerError)
356 .with_description(format!("{e}")),
357 ),
358 ),
359 }
360 .into_response()
361 }
362}
363
364impl<S, F> FromRequest<S> for ClientAuthorization<F>
365where
366 F: DeserializeOwned,
367 S: Send + Sync,
368{
369 type Rejection = ClientAuthorizationError;
370
371 #[allow(clippy::too_many_lines)]
372 async fn from_request(
373 req: Request<axum::body::Body>,
374 state: &S,
375 ) -> Result<Self, Self::Rejection> {
376 enum Authorization {
377 Basic(String, String),
378 Bearer(String),
379 }
380
381 let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
384 let bytes = header.as_bytes();
385 if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
386 let Some(decoded) = Basic::decode(header) else {
387 return Err(ClientAuthorizationError::InvalidHeader);
388 };
389
390 Some(Authorization::Basic(
391 decoded.username().to_owned(),
392 decoded.password().to_owned(),
393 ))
394 } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
395 let Some(decoded) = Bearer::decode(header) else {
396 return Err(ClientAuthorizationError::InvalidHeader);
397 };
398
399 Some(Authorization::Bearer(decoded.token().to_owned()))
400 } else {
401 return Err(ClientAuthorizationError::InvalidHeader);
402 }
403 } else {
404 None
405 };
406
407 let (
409 client_id_from_form,
410 client_secret_from_form,
411 client_assertion_type,
412 client_assertion,
413 form,
414 ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
415 Ok(Form(form)) => (
416 form.client_id,
417 form.client_secret,
418 form.client_assertion_type,
419 form.client_assertion,
420 Some(form.inner),
421 ),
422 Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
424 Err(FormRejection::FailedToDeserializeForm(err)) => {
426 return Err(ClientAuthorizationError::BadForm(err));
427 }
428 Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
430 };
431
432 let credentials = match (
434 authorization,
435 client_id_from_form,
436 client_secret_from_form,
437 client_assertion_type,
438 client_assertion,
439 ) {
440 (
441 Some(Authorization::Basic(client_id, client_secret)),
442 client_id_from_form,
443 None,
444 None,
445 None,
446 ) => {
447 if let Some(client_id_from_form) = client_id_from_form {
448 if client_id != client_id_from_form {
450 return Err(ClientAuthorizationError::ClientIdMismatch {
451 credential: client_id,
452 form: client_id_from_form,
453 });
454 }
455 }
456
457 Credentials::ClientSecretBasic {
458 client_id,
459 client_secret,
460 }
461 }
462
463 (None, Some(client_id), Some(client_secret), None, None) => {
464 Credentials::ClientSecretPost {
466 client_id,
467 client_secret,
468 }
469 }
470
471 (None, Some(client_id), None, None, None) => {
472 Credentials::None { client_id }
474 }
475
476 (
477 None,
478 client_id_from_form,
479 None,
480 Some(client_assertion_type),
481 Some(client_assertion),
482 ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
483 let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
485 .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
486
487 let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
488 client_id.clone()
489 } else {
490 return Err(ClientAuthorizationError::InvalidAssertion);
491 };
492
493 if let Some(client_id_from_form) = client_id_from_form {
494 if client_id != client_id_from_form {
496 return Err(ClientAuthorizationError::ClientIdMismatch {
497 credential: client_id,
498 form: client_id_from_form,
499 });
500 }
501 }
502
503 Credentials::ClientAssertionJwtBearer {
504 client_id,
505 jwt: Box::new(jwt),
506 }
507 }
508
509 (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
510 return Err(ClientAuthorizationError::UnsupportedClientAssertion {
512 client_assertion_type,
513 });
514 }
515
516 (Some(Authorization::Bearer(token)), None, None, None, None) => {
517 Credentials::BearerToken { token }
519 }
520
521 (None, None, None, None, None) => {
522 return Err(ClientAuthorizationError::MissingCredentials);
524 }
525
526 _ => {
527 return Err(ClientAuthorizationError::InvalidRequest);
529 }
530 };
531
532 Ok(ClientAuthorization { credentials, form })
533 }
534}
535
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Basic` header encoding `client-id:client-secret`.
    const BASIC_HEADER: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Builds a form-encoded POST request, with an optional `Authorization`
    /// header.
    fn form_request(authorization: Option<&str>, body: &str) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );

        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }

        builder.body(Body::new(body.to_owned())).unwrap()
    }

    /// Runs the extractor with a JSON value as the form payload type.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    /// The extraction expected when the only extra form field is `foo=bar`.
    fn with_foo_bar(credentials: Credentials) -> ClientAuthorization<serde_json::Value> {
        ClientAuthorization {
            credentials,
            form: Some(serde_json::json!({"foo": "bar"})),
        }
    }

    #[tokio::test]
    async fn none_test() {
        let req = form_request(None, "client_id=client-id&foo=bar");

        assert_eq!(
            extract(req).await.unwrap(),
            with_foo_bar(Credentials::None {
                client_id: "client-id".to_owned(),
            })
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = || {
            with_foo_bar(Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            })
        };

        // Header only
        let req = form_request(Some(BASIC_HEADER), "foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected());

        // Header plus a matching client_id in the form
        let req = form_request(Some(BASIC_HEADER), "client_id=client-id&foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected());

        // Header plus a different client_id in the form
        let req = form_request(Some(BASIC_HEADER), "client_id=mismatch-id&foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // Undecodable header
        let req = form_request(Some("Basic invalid"), "foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let req = form_request(
            None,
            "client_id=client-id&client_secret=client-secret&foo=bar",
        );

        assert_eq!(
            extract(req).await.unwrap(),
            with_foo_bar(Credentials::ClientSecretPost {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            })
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        // HS256 JWT signed with "client-secret", with `sub` set to "client-id"
        let assertion = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={assertion}&foo=bar",
        );

        let authz = extract(form_request(None, &body)).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn bearer_token_test() {
        let req = form_request(Some("Bearer token"), "foo=bar");

        assert_eq!(
            extract(req).await.unwrap(),
            with_foo_bar(Credentials::BearerToken {
                token: "token".to_owned(),
            })
        );
    }
}