1use crate::https::extractors::ClientConnInfo;
2use crate::https::ServerState;
3use axum::{
4 body::Body,
5 extract::{connect_info::ConnectInfo, State},
6 http::{header::HeaderName, StatusCode},
7 http::{HeaderValue, Request},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 RequestExt,
11};
12use kanidm_proto::constants::{KOPID, KVERSION, X_FORWARDED_FOR};
13use std::net::IpAddr;
14use uuid::Uuid;
1516#[allow(clippy::declare_interior_mutable_const)]
17const X_FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static(X_FORWARDED_FOR);
1819pub(crate) mod caching;
20pub(crate) mod compression;
21pub(crate) mod hsts_header;
22pub(crate) mod security_headers;
2324// the version middleware injects
25const KANIDM_VERSION: &str = env!("CARGO_PKG_VERSION");
2627/// Injects a header into the response with "X-KANIDM-VERSION" matching the version of the package.
28pub async fn version_middleware(request: Request<Body>, next: Next) -> Response {
29let mut response = next.run(request).await;
30 response
31 .headers_mut()
32 .insert(KVERSION, HeaderValue::from_static(KANIDM_VERSION));
33 response
34}
3536#[cfg(any(test, debug_assertions))]
37/// This is a debug middleware to ensure that /v1/ endpoints only return JSON
38#[instrument(level = "trace", name = "are_we_json_yet", skip_all)]
39pub async fn are_we_json_yet(request: Request<Body>, next: Next) -> Response {
40let uri = request.uri().path().to_string();
4142let response = next.run(request).await;
4344if uri.starts_with("/v1") && response.status().is_success() {
45let headers = response.headers();
46assert!(headers.contains_key(axum::http::header::CONTENT_TYPE));
47assert!(
48 headers.get(axum::http::header::CONTENT_TYPE)
49 == Some(&HeaderValue::from_static(
50 kanidm_proto::constants::APPLICATION_JSON
51 ))
52 );
53 }
5455 response
56}
5758#[derive(Clone, Debug)]
59/// For holding onto the event ID and other handy request-based things
60pub struct KOpId {
61/// The event correlation ID
62pub eventid: Uuid,
63}
6465/// This runs at the start of the request, adding an extension with `KOpId` which has useful things inside it.
66#[instrument(level = "trace", name = "kopid_middleware", skip_all)]
67pub async fn kopid_middleware(mut request: Request<Body>, next: Next) -> Response {
68// generate the event ID
69let eventid = sketching::tracing_forest::id();
7071// insert the extension so we can pull it out later
72request.extensions_mut().insert(KOpId { eventid });
73let mut response = next.run(request).await;
7475// This conversion *should never* fail. If it does, rather than panic, we warn and
76 // just don't put the id in the response.
77let _ = HeaderValue::from_str(&eventid.as_hyphenated().to_string())
78 .map(|hv| response.headers_mut().insert(KOPID, hv))
79 .map_err(|err| {
80warn!(?err, "An invalid operation id was encountered");
81 });
8283 response
84}
8586// This middleware extracts the ip_address and client information, and stores it
87// in the request extensions for future layers to use it.
88pub async fn ip_address_middleware(
89 State(state): State<ServerState>,
90mut request: Request<Body>,
91 next: Next,
92) -> Response {
93match ip_address_middleware_inner(&state, &mut request).await {
94Ok(trusted_client_ip) => {
95// By this point, proxy-v2 AND x-forward-for have resolved, so we can finally display this information.
96info!(connection_addr = %trusted_client_ip.connection_addr, client_ip_addr = %trusted_client_ip.client_ip_addr);
97 request.extensions_mut().insert(trusted_client_ip);
98 next.run(request).await
99}
100Err(err_status_and_reason) => err_status_and_reason.into_response(),
101 }
102}
103104async fn ip_address_middleware_inner(
105 state: &ServerState,
106 request: &mut Request<Body>,
107) -> Result<ClientConnInfo, (StatusCode, &'static str)> {
108// Extract the IP and insert it to the request.
109let ConnectInfo(ClientConnInfo {
110 connection_addr,
111 client_ip_addr,
112 client_cert,
113 }) = request
114 .extract_parts::<ConnectInfo<ClientConnInfo>>()
115 .await
116.map_err(|_| {
117error!("Connect info contains invalid data");
118 (
119 StatusCode::INTERNAL_SERVER_ERROR,
120"connect info contains invalid data",
121 )
122 })?;
123124let connection_ip_addr = connection_addr.ip();
125126let trust_x_forward_for = state
127 .trust_x_forward_for_ips
128 .as_ref()
129 .map(|range| range.contains(&connection_ip_addr))
130 .unwrap_or_default();
131132let client_ip_addr = if trust_x_forward_for {
133if let Some(x_forward_for) = request.headers().get(X_FORWARDED_FOR_HEADER) {
134// X forward for may be comma separated.
135let first = x_forward_for
136 .to_str()
137 .map(|s|
138// Split on an optional comma, return the first result.
139s.split(',').next().unwrap_or(s))
140 .map_err(|_| {
141 (
142 StatusCode::BAD_REQUEST,
143"X-Forwarded-For contains invalid data",
144 )
145 })?;
146147 first.parse::<IpAddr>().map_err(|_| {
148 (
149 StatusCode::BAD_REQUEST,
150"X-Forwarded-For contains invalid ip addr",
151 )
152 })?
153} else {
154 client_ip_addr
155 }
156 } else {
157// This can either be the client_addr == connection_addr if there are
158 // no ip address trust sources, or this is the value as reported by
159 // proxy protocol header. If the proxy protocol header is used, then
160 // trust_x_forward_for can never have been true so we catch here.
161client_ip_addr
162 };
163164Ok(ClientConnInfo {
165 connection_addr,
166 client_ip_addr,
167 client_cert,
168 })
169}