use std::collections::{BTreeMap, BTreeSet};
use super::errors::WebError;
use super::middleware::KOpId;
use super::ServerState;
use crate::https::extractors::VerifiedClientInformation;
use axum::{
body::Body,
extract::{Path, Query, State},
http::{
header::{
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION,
WWW_AUTHENTICATE,
},
HeaderValue, StatusCode,
},
middleware::from_fn,
response::{IntoResponse, Response},
routing::{get, post},
Extension, Form, Json, Router,
};
use axum_macros::debug_handler;
use kanidm_proto::constants::uri::{
OAUTH2_AUTHORISE, OAUTH2_AUTHORISE_PERMIT, OAUTH2_AUTHORISE_REJECT,
};
use kanidm_proto::constants::APPLICATION_JSON;
use kanidm_proto::oauth2::AuthorisationResponse;
#[cfg(feature = "dev-oauth2-device-flow")]
use kanidm_proto::oauth2::DeviceAuthorizationResponse;
use kanidmd_lib::idm::oauth2::{
AccessTokenIntrospectRequest, AccessTokenRequest, AuthorisationRequest, AuthorisePermitSuccess,
AuthoriseResponse, ErrorResponse, Oauth2Error, TokenRevokeRequest,
};
use kanidmd_lib::prelude::f_eq;
use kanidmd_lib::prelude::*;
use kanidmd_lib::value::PartialValue;
use serde::{Deserialize, Serialize};
use serde_with::formats::CommaSeparator;
use serde_with::{serde_as, StringWithSeparator};
#[cfg(feature = "dev-oauth2-device-flow")]
use uri::OAUTH2_AUTHORISE_DEVICE;
use uri::{OAUTH2_TOKEN_ENDPOINT, OAUTH2_TOKEN_INTROSPECT_ENDPOINT, OAUTH2_TOKEN_REVOKE_ENDPOINT};
pub struct HTTPOauth2Error(Oauth2Error);
impl IntoResponse for HTTPOauth2Error {
fn into_response(self) -> Response {
let HTTPOauth2Error(error) = self;
if let Oauth2Error::AuthenticationRequired = error {
(
StatusCode::UNAUTHORIZED,
[
(WWW_AUTHENTICATE, "Bearer"),
(ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
],
)
.into_response()
} else {
let err = ErrorResponse {
error: error.to_string(),
..Default::default()
};
let body = match serde_json::to_string(&err) {
Ok(val) => val,
Err(e) => {
admin_warn!("Failed to serialize error response: original_error=\"{:?}\" serialization_error=\"{:?}\"", err, e);
format!("{:?}", err)
}
};
(
StatusCode::BAD_REQUEST,
[(ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
body,
)
.into_response()
}
}
}
pub(crate) fn oauth2_id(rs_name: &str) -> Filter<FilterInvalid> {
filter_all!(f_and!([
f_eq(Attribute::Class, EntryClass::OAuth2ResourceServer.into()),
f_eq(Attribute::Name, PartialValue::new_iname(rs_name))
]))
}
/// Fetch the image configured for the OAuth2 client named `rs_name`.
///
/// Returns the raw image bytes with a Content-Type derived from the stored file type, a
/// 404 Not Found when the client has no image set, or the WebError rendering of any lookup
/// failure.
#[utoipa::path(
    get,
    path = "/ui/images/oauth2/{rs_name}",
    operation_id = "oauth2_image_get",
    responses(
        (status = 200, description = "Ok", body=&[u8]),
        (status = 401, description = "Authorization required"),
        (status = 403, description = "Not Authorized"),
        (status = 404, description = "No image set for this OAuth2 client"),
    ),
    security(("token_jwt" = [])),
    tag = "ui",
)]
pub(crate) async fn oauth2_image_get(
    State(state): State<ServerState>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Path(rs_name): Path<String>,
) -> Response {
    let rs_filter = oauth2_id(&rs_name);
    let res = state
        .qe_r_ref
        .handle_oauth2_rs_image_get_image(client_auth_info, rs_filter)
        .await;

    match res {
        Ok(Some(image)) => (
            StatusCode::OK,
            [(CONTENT_TYPE, image.filetype.as_content_type_str())],
            image.contents,
        )
            .into_response(),
        Ok(None) => {
            // Documented above as the 404 response.
            warn!(?rs_name, "No image set for OAuth2 client");
            (StatusCode::NOT_FOUND, "").into_response()
        }
        Err(err) => WebError::from(err).into_response(),
    }
}
/// `POST` variant of the OAuth2 authorisation endpoint. The request arrives as a JSON body.
///
/// Runs [`oauth2_authorise`]. A resulting `302 Found` is reported as `200 OK` with the
/// `Location` header left in place, presumably so a script-driven caller can act on it itself.
// `client_auth_info` is skipped, as `oauth2_token_post` already does. It carries the caller's
// `bearer_token`, which its Debug impl may print into the span.
#[instrument(level = "debug", skip(state, kopid, client_auth_info))]
pub async fn oauth2_authorise_post(
    State(state): State<ServerState>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Json(auth_req): Json<AuthorisationRequest>,
) -> impl IntoResponse {
    let mut res = oauth2_authorise(state, auth_req, kopid, client_auth_info)
        .await
        .into_response();
    if res.status() == StatusCode::FOUND {
        *res.status_mut() = StatusCode::OK;
    }
    res
}
/// `GET` variant of the OAuth2 authorisation endpoint. The request arrives as query parameters.
///
/// Returns whatever [`oauth2_authorise`] produces, including a real `302 Found` redirect.
// `client_auth_info` is skipped, as `oauth2_token_post` already does. It carries the caller's
// `bearer_token`, which its Debug impl may print into the span.
#[instrument(level = "debug", skip(state, kopid, client_auth_info))]
pub async fn oauth2_authorise_get(
    State(state): State<ServerState>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Query(auth_req): Query<AuthorisationRequest>,
) -> impl IntoResponse {
    oauth2_authorise(state, auth_req, kopid, client_auth_info).await
}
async fn oauth2_authorise(
state: ServerState,
auth_req: AuthorisationRequest,
kopid: KOpId,
client_auth_info: ClientAuthInfo,
) -> impl IntoResponse {
let res: Result<AuthoriseResponse, Oauth2Error> = state
.qe_r_ref
.handle_oauth2_authorise(client_auth_info, auth_req, kopid.eventid)
.await;
match res {
Ok(AuthoriseResponse::ConsentRequested {
client_name,
scopes,
pii_scopes,
consent_token,
}) => {
#[allow(clippy::unwrap_used)]
let body = serde_json::to_string(&AuthorisationResponse::ConsentRequested {
client_name,
scopes,
pii_scopes,
consent_token,
})
.unwrap();
#[allow(clippy::unwrap_used)]
Response::builder()
.status(StatusCode::OK)
.body(body.into())
.unwrap()
}
Ok(AuthoriseResponse::Permitted(AuthorisePermitSuccess {
mut redirect_uri,
state,
code,
})) => {
#[allow(clippy::unwrap_used)]
let body =
Body::from(serde_json::to_string(&AuthorisationResponse::Permitted).unwrap());
redirect_uri
.query_pairs_mut()
.clear()
.append_pair("state", &state)
.append_pair("code", &code);
#[allow(clippy::unwrap_used)]
Response::builder()
.status(StatusCode::FOUND)
.header(
LOCATION,
HeaderValue::from_str(redirect_uri.as_str()).unwrap(),
)
.header(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_str(&redirect_uri.origin().ascii_serialization()).unwrap(),
)
.body(body)
.unwrap()
}
Ok(AuthoriseResponse::AuthenticationRequired { .. })
| Err(Oauth2Error::AuthenticationRequired) => {
#[allow(clippy::unwrap_used)]
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(WWW_AUTHENTICATE, HeaderValue::from_static("Bearer"))
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::empty())
.unwrap()
}
Err(Oauth2Error::AccessDenied) => {
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::FORBIDDEN)
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::empty())
.expect("Failed to generate a forbidden response")
}
Err(e) => {
admin_error!(
"Unable to authorise - Error ID: {:?} error: {}",
kopid.eventid,
&e.to_string()
);
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::empty())
.expect("Failed to generate a bad request response")
}
}
}
pub async fn oauth2_authorise_permit_post(
State(state): State<ServerState>,
Extension(kopid): Extension<KOpId>,
VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
Json(consent_req): Json<String>,
) -> impl IntoResponse {
let mut res = oauth2_authorise_permit(state, consent_req, kopid, client_auth_info)
.await
.into_response();
if res.status() == StatusCode::FOUND {
*res.status_mut() = StatusCode::OK;
}
res
}
/// Consent token submitted when permitting or rejecting an OAuth2 authorisation request.
///
/// Extracted from the query string by the GET handlers and from a form body by the reject POST
/// handler.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConsentRequestData {
    /// Opaque consent token. Presumably the `consent_token` issued in a `ConsentRequested`
    /// response; verify against the front end.
    token: String,
}
pub async fn oauth2_authorise_permit_get(
State(state): State<ServerState>,
Query(token): Query<ConsentRequestData>,
Extension(kopid): Extension<KOpId>,
VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
) -> impl IntoResponse {
oauth2_authorise_permit(state, token.token, kopid, client_auth_info).await
}
async fn oauth2_authorise_permit(
state: ServerState,
consent_req: String,
kopid: KOpId,
client_auth_info: ClientAuthInfo,
) -> impl IntoResponse {
let res = state
.qe_w_ref
.handle_oauth2_authorise_permit(client_auth_info, consent_req, kopid.eventid)
.await;
match res {
Ok(AuthorisePermitSuccess {
mut redirect_uri,
state,
code,
}) => {
redirect_uri
.query_pairs_mut()
.clear()
.append_pair("state", &state)
.append_pair("code", &code);
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::FOUND)
.header(LOCATION, redirect_uri.as_str())
.header(
ACCESS_CONTROL_ALLOW_ORIGIN,
redirect_uri.origin().ascii_serialization(),
)
.body(Body::empty())
.expect("Failed to generate response")
}
Err(err) => {
match err {
OperationError::NotAuthenticated => {
WebError::from(err).response_with_access_control_origin_header()
}
_ => {
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::empty())
.expect("Failed to generate error response")
}
}
}
}
}
pub async fn oauth2_authorise_reject_post(
State(state): State<ServerState>,
Extension(kopid): Extension<KOpId>,
VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
Form(consent_req): Form<ConsentRequestData>,
) -> Response<Body> {
oauth2_authorise_reject(state, consent_req.token, kopid, client_auth_info).await
}
pub async fn oauth2_authorise_reject_get(
State(state): State<ServerState>,
Extension(kopid): Extension<KOpId>,
VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
Query(consent_req): Query<ConsentRequestData>,
) -> Response<Body> {
oauth2_authorise_reject(state, consent_req.token, kopid, client_auth_info).await
}
async fn oauth2_authorise_reject(
state: ServerState,
consent_req: String,
kopid: KOpId,
client_auth_info: ClientAuthInfo,
) -> Response<Body> {
let res = state
.qe_r_ref
.handle_oauth2_authorise_reject(client_auth_info, consent_req, kopid.eventid)
.await;
match res {
Ok(mut redirect_uri) => {
redirect_uri
.query_pairs_mut()
.clear()
.append_pair("error", "access_denied")
.append_pair("error_description", "authorisation rejected");
#[allow(clippy::unwrap_used)]
Response::builder()
.header(LOCATION, redirect_uri.as_str())
.header(
ACCESS_CONTROL_ALLOW_ORIGIN,
redirect_uri.origin().ascii_serialization(),
)
.body(Body::empty())
.unwrap()
}
Err(err) => {
match err {
OperationError::NotAuthenticated => {
WebError::from(err).response_with_access_control_origin_header()
}
_ => {
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::empty())
.expect("Failed to generate an error response")
}
}
}
}
}
/// OAuth2 token endpoint: exchanges a form-encoded token request for a token response.
///
/// Success is `200 OK` with the token response as JSON. Failures are rendered through
/// [`HTTPOauth2Error`]. Both carry `Access-Control-Allow-Origin: *`.
// NOTE(review): `tok_req` is still recorded as a span field via Debug. Confirm that its Debug
// impl does not expose codes or client secrets.
#[axum_macros::debug_handler]
#[instrument(skip(state, kopid, client_auth_info), level = "DEBUG")]
pub async fn oauth2_token_post(
    State(state): State<ServerState>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Form(tok_req): Form<AccessTokenRequest>,
) -> impl IntoResponse {
    let exchange = state
        .qe_w_ref
        .handle_oauth2_token_exchange(client_auth_info, tok_req, kopid.eventid)
        .await;

    exchange
        .map(|token_response| {
            (
                StatusCode::OK,
                [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
                Json(token_response),
            )
                .into_response()
        })
        .unwrap_or_else(|oauth_err| HTTPOauth2Error(oauth_err).into_response())
}
/// Serve the OpenID Connect discovery document for `client_id`.
///
/// Success is `200 OK` with the document as JSON. A failure is logged and rendered through
/// [`WebError`]. Both carry an `Access-Control-Allow-Origin` header.
pub async fn oauth2_openid_discovery_get(
    State(state): State<ServerState>,
    Path(client_id): Path<String>,
    Extension(kopid): Extension<KOpId>,
) -> impl IntoResponse {
    let lookup = state
        .qe_r_ref
        .handle_oauth2_openid_discovery(client_id, kopid.eventid)
        .await;

    let discovery = match lookup {
        Ok(discovery) => discovery,
        Err(e) => {
            error!(err = ?e, "Unable to access discovery info");
            return WebError::from(e).response_with_access_control_origin_header();
        }
    };

    (
        StatusCode::OK,
        [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
        Json(discovery),
    )
        .into_response()
}
/// Serve the RFC 8414 authorisation-server metadata document for `client_id`.
///
/// Success is `200 OK` with the metadata as JSON. A failure is logged and rendered through
/// [`WebError`]. Both carry an `Access-Control-Allow-Origin` header.
pub async fn oauth2_rfc8414_metadata_get(
    State(state): State<ServerState>,
    Path(client_id): Path<String>,
    Extension(kopid): Extension<KOpId>,
) -> impl IntoResponse {
    state
        .qe_r_ref
        .handle_oauth2_rfc8414_metadata(client_id, kopid.eventid)
        .await
        .map(|metadata| {
            (
                StatusCode::OK,
                [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
                Json(metadata),
            )
                .into_response()
        })
        .unwrap_or_else(|e| {
            error!(err = ?e, "Unable to access discovery info");
            WebError::from(e).response_with_access_control_origin_header()
        })
}
/// OpenID Connect userinfo endpoint for `client_id`.
///
/// Requires a bearer token. Without one the request is rejected via
/// `Oauth2Error::AuthenticationRequired` (401 with `WWW-Authenticate: Bearer`). Success is
/// `200 OK` with the userinfo claims as JSON. Other failures go through [`HTTPOauth2Error`].
#[debug_handler]
pub async fn oauth2_openid_userinfo_get(
    State(state): State<ServerState>,
    Path(client_id): Path<String>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
) -> Response {
    let Some(client_token) = client_auth_info.bearer_token else {
        error!("Bearer Authentication Not Provided");
        return HTTPOauth2Error(Oauth2Error::AuthenticationRequired).into_response();
    };

    state
        .qe_r_ref
        .handle_oauth2_openid_userinfo(client_id, client_token, kopid.eventid)
        .await
        .map(|userinfo| {
            (
                StatusCode::OK,
                [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
                Json(userinfo),
            )
                .into_response()
        })
        .unwrap_or_else(|oauth_err| HTTPOauth2Error(oauth_err).into_response())
}
pub async fn oauth2_openid_publickey_get(
State(state): State<ServerState>,
Path(client_id): Path<String>,
Extension(kopid): Extension<KOpId>,
) -> Response {
let res = state
.qe_r_ref
.handle_oauth2_openid_publickey(client_id, kopid.eventid)
.await
.map(Json::from)
.map_err(WebError::from);
match res {
Ok(jsn) => (StatusCode::OK, [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")], jsn).into_response(),
Err(web_err) => web_err.response_with_access_control_origin_header(),
}
}
pub async fn oauth2_token_introspect_post(
State(state): State<ServerState>,
Extension(kopid): Extension<KOpId>,
VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
Form(intr_req): Form<AccessTokenIntrospectRequest>,
) -> impl IntoResponse {
request_trace!("Introspect Request - {:?}", intr_req);
let res = state
.qe_r_ref
.handle_oauth2_token_introspect(client_auth_info, intr_req, kopid.eventid)
.await;
match res {
Ok(atr) => {
let body = match serde_json::to_string(&atr) {
Ok(val) => val,
Err(e) => {
admin_warn!("Failed to serialize introspect response: original_data=\"{:?}\" serialization_error=\"{:?}\"", atr, e);
format!("{:?}", atr)
}
};
#[allow(clippy::unwrap_used)]
Response::builder()
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.header(CONTENT_TYPE, APPLICATION_JSON)
.body(Body::from(body))
.unwrap()
}
Err(Oauth2Error::AuthenticationRequired) => {
#[allow(clippy::expect_used)]
Response::builder()
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.expect("Failed to generate an unauthorized response")
}
Err(e) => {
let err = ErrorResponse {
error: e.to_string(),
..Default::default()
};
let body = match serde_json::to_string(&err) {
Ok(val) => val,
Err(e) => {
format!("{:?}", e)
}
};
#[allow(clippy::expect_used)]
Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.body(Body::from(body))
.expect("Failed to generate an error response")
}
}
}
/// OAuth2 token revocation endpoint (form-encoded request).
///
/// * success → empty `200 OK`;
/// * `Oauth2Error::AuthenticationRequired` → empty `401 Unauthorized`;
/// * any other error → `400 Bad Request` with a JSON `ErrorResponse` body, or an empty body if
///   serialisation fails.
///
/// Every response carries `Access-Control-Allow-Origin: *`.
pub async fn oauth2_token_revoke_post(
    State(state): State<ServerState>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Form(intr_req): Form<TokenRevokeRequest>,
) -> impl IntoResponse {
    // NOTE(review): this traces the whole request with Debug, which presumably includes the
    // token being revoked. Confirm the Debug impl redacts it.
    request_trace!("Revoke Request - {:?}", intr_req);
    let outcome = state
        .qe_w_ref
        .handle_oauth2_token_revoke(client_auth_info, intr_req, kopid.eventid)
        .await;

    let cors = [(ACCESS_CONTROL_ALLOW_ORIGIN, "*")];
    match outcome {
        Ok(()) => (StatusCode::OK, cors, "").into_response(),
        Err(Oauth2Error::AuthenticationRequired) => {
            (StatusCode::UNAUTHORIZED, cors, "").into_response()
        }
        Err(e) => {
            let payload = ErrorResponse {
                error: e.to_string(),
                ..Default::default()
            };
            let body = serde_json::to_string(&payload).unwrap_or_default();
            (StatusCode::BAD_REQUEST, cors, body).into_response()
        }
    }
}
/// Answer CORS preflight (`OPTIONS`) requests for the OAuth2 endpoints.
///
/// Allows any origin and the `Authorization` request header. The body is empty.
pub async fn oauth2_preflight_options() -> Response {
    let cors_headers = [
        (ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
        (ACCESS_CONTROL_ALLOW_HEADERS, "Authorization"),
    ];
    (StatusCode::OK, cors_headers, String::new()).into_response()
}
/// Form body accepted by the OAuth2 device authorisation endpoint.
#[serde_as]
#[derive(Deserialize, Debug, Serialize)]
pub(crate) struct DeviceFlowForm {
    /// Identifier of the client starting the device flow.
    client_id: String,
    /// Requested scopes, parsed from a single comma-separated string. `None` when omitted.
    // NOTE(review): RFC 6749 §3.3 (referenced by RFC 8628 §3.1) defines `scope` as
    // space-delimited, so a standard client sending "openid profile" yields one scope here.
    // Confirm which separator clients use.
    #[serde_as(as = "Option<StringWithSeparator::<CommaSeparator, String>>")]
    scope: Option<BTreeSet<String>>,
    /// Any other form fields, collected here rather than rejected.
    #[serde(flatten)]
    extra: BTreeMap<String, String>,
}
/// Start an OAuth2 device authorisation flow (only built with `dev-oauth2-device-flow`).
///
/// Passes the form's client id and requested scopes to the write query engine. On success the
/// device authorisation response is returned as JSON. Failures are rendered via
/// [`HTTPOauth2Error`].
#[cfg(feature = "dev-oauth2-device-flow")]
#[instrument(level = "info", skip(state, kopid, client_auth_info))]
pub(crate) async fn oauth2_authorise_device_post(
    State(state): State<ServerState>,
    Extension(kopid): Extension<KOpId>,
    VerifiedClientInformation(client_auth_info): VerifiedClientInformation,
    Form(form): Form<DeviceFlowForm>,
) -> Result<Json<DeviceAuthorizationResponse>, HTTPOauth2Error> {
    let DeviceFlowForm {
        client_id, scope, ..
    } = form;

    let device_response = state
        .qe_w_ref
        .handle_oauth2_device_flow_start(client_auth_info, &client_id, &scope, kopid.eventid)
        .await
        .map_err(HTTPOauth2Error)?;

    Ok(Json(device_response))
}
/// Build the router for the OAuth2 / OpenID Connect endpoints defined in this module.
///
/// The per-client OpenID routes (discovery, userinfo, public key, RFC 8414 metadata) and the
/// token endpoint also answer CORS preflight `OPTIONS` requests. The device authorisation route
/// is registered only with the `dev-oauth2-device-flow` feature. All routes are wrapped in the
/// `dont_cache_me` middleware.
pub fn route_setup(state: ServerState) -> Router<ServerState> {
    let openid_routes = Router::new()
        .route(
            "/oauth2/openid/:client_id/.well-known/openid-configuration",
            get(oauth2_openid_discovery_get).options(oauth2_preflight_options),
        )
        .route(
            "/oauth2/openid/:client_id/userinfo",
            get(oauth2_openid_userinfo_get).options(oauth2_preflight_options),
        )
        .route(
            "/oauth2/openid/:client_id/public_key.jwk",
            get(oauth2_openid_publickey_get).options(oauth2_preflight_options),
        )
        .route(
            "/oauth2/openid/:client_id/.well-known/oauth-authorization-server",
            get(oauth2_rfc8414_metadata_get).options(oauth2_preflight_options),
        )
        .with_state(state.clone());

    let app = Router::new()
        .route("/oauth2", get(super::v1_oauth2::oauth2_get))
        .route(
            OAUTH2_AUTHORISE,
            post(oauth2_authorise_post).get(oauth2_authorise_get),
        )
        .route(
            OAUTH2_AUTHORISE_PERMIT,
            post(oauth2_authorise_permit_post).get(oauth2_authorise_permit_get),
        )
        .route(
            OAUTH2_AUTHORISE_REJECT,
            post(oauth2_authorise_reject_post).get(oauth2_authorise_reject_get),
        );

    #[cfg(feature = "dev-oauth2-device-flow")]
    let app = app.route(OAUTH2_AUTHORISE_DEVICE, post(oauth2_authorise_device_post));

    app.route(
        OAUTH2_TOKEN_ENDPOINT,
        post(oauth2_token_post).options(oauth2_preflight_options),
    )
    .route(
        OAUTH2_TOKEN_INTROSPECT_ENDPOINT,
        post(oauth2_token_introspect_post),
    )
    .route(OAUTH2_TOKEN_REVOKE_ENDPOINT, post(oauth2_token_revoke_post))
    .merge(openid_routes)
    .with_state(state)
    .layer(from_fn(super::middleware::caching::dont_cache_me))
}