kanidmd_core/https/middleware/mod.rs

1use crate::https::extractors::ClientConnInfo;
2use crate::https::ServerState;
3use axum::{
4    body::Body,
5    extract::{connect_info::ConnectInfo, State},
6    http::{header::HeaderName, StatusCode},
7    http::{HeaderValue, Request},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    RequestExt,
11};
12use kanidm_proto::constants::{KOPID, KVERSION, X_FORWARDED_FOR};
13use std::net::IpAddr;
14use uuid::Uuid;
15
16#[allow(clippy::declare_interior_mutable_const)]
17const X_FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static(X_FORWARDED_FOR);
18
19pub(crate) mod caching;
20pub(crate) mod compression;
21pub(crate) mod hsts_header;
22pub(crate) mod security_headers;
23
24// the version middleware injects
25const KANIDM_VERSION: &str = env!("CARGO_PKG_VERSION");
26
27/// Injects a header into the response with "X-KANIDM-VERSION" matching the version of the package.
28pub async fn version_middleware(request: Request<Body>, next: Next) -> Response {
29    let mut response = next.run(request).await;
30    response
31        .headers_mut()
32        .insert(KVERSION, HeaderValue::from_static(KANIDM_VERSION));
33    response
34}
35
36#[cfg(any(test, debug_assertions))]
37/// This is a debug middleware to ensure that /v1/ endpoints only return JSON
38#[instrument(level = "trace", name = "are_we_json_yet", skip_all)]
39pub async fn are_we_json_yet(request: Request<Body>, next: Next) -> Response {
40    let uri = request.uri().path().to_string();
41
42    let response = next.run(request).await;
43
44    if uri.starts_with("/v1") && response.status().is_success() {
45        let headers = response.headers();
46        assert!(headers.contains_key(axum::http::header::CONTENT_TYPE));
47        assert!(
48            headers.get(axum::http::header::CONTENT_TYPE)
49                == Some(&HeaderValue::from_static(
50                    kanidm_proto::constants::APPLICATION_JSON
51                ))
52        );
53    }
54
55    response
56}
57
58#[derive(Clone, Debug)]
59/// For holding onto the event ID and other handy request-based things
60pub struct KOpId {
61    /// The event correlation ID
62    pub eventid: Uuid,
63}
64
65/// This runs at the start of the request, adding an extension with `KOpId` which has useful things inside it.
66#[instrument(level = "trace", name = "kopid_middleware", skip_all)]
67pub async fn kopid_middleware(mut request: Request<Body>, next: Next) -> Response {
68    // generate the event ID
69    let eventid = sketching::tracing_forest::id();
70
71    // insert the extension so we can pull it out later
72    request.extensions_mut().insert(KOpId { eventid });
73    let mut response = next.run(request).await;
74
75    // This conversion *should never* fail. If it does, rather than panic, we warn and
76    // just don't put the id in the response.
77    let _ = HeaderValue::from_str(&eventid.as_hyphenated().to_string())
78        .map(|hv| response.headers_mut().insert(KOPID, hv))
79        .map_err(|err| {
80            warn!(?err, "An invalid operation id was encountered");
81        });
82
83    response
84}
85
86// This middleware extracts the ip_address and client information, and stores it
87// in the request extensions for future layers to use it.
88pub async fn ip_address_middleware(
89    State(state): State<ServerState>,
90    mut request: Request<Body>,
91    next: Next,
92) -> Response {
93    match ip_address_middleware_inner(&state, &mut request).await {
94        Ok(trusted_client_ip) => {
95            // By this point, proxy-v2 AND x-forward-for have resolved, so we can finally display this information.
96            info!(connection_addr = %trusted_client_ip.connection_addr, client_ip_addr = %trusted_client_ip.client_ip_addr);
97            request.extensions_mut().insert(trusted_client_ip);
98            next.run(request).await
99        }
100        Err(err_status_and_reason) => err_status_and_reason.into_response(),
101    }
102}
103
104async fn ip_address_middleware_inner(
105    state: &ServerState,
106    request: &mut Request<Body>,
107) -> Result<ClientConnInfo, (StatusCode, &'static str)> {
108    // Extract the IP and insert it to the request.
109    let ConnectInfo(ClientConnInfo {
110        connection_addr,
111        client_ip_addr,
112        client_cert,
113    }) = request
114        .extract_parts::<ConnectInfo<ClientConnInfo>>()
115        .await
116        .map_err(|_| {
117            error!("Connect info contains invalid data");
118            (
119                StatusCode::INTERNAL_SERVER_ERROR,
120                "connect info contains invalid data",
121            )
122        })?;
123
124    // to_canonical maps linux ipv4 in ipv6 to an ipv4 addr.
125    let connection_ip_addr = connection_addr.ip().to_canonical();
126
127    let trust_x_forward_for = state
128        .trust_x_forward_for_ips
129        .as_ref()
130        .map(|range| range.contains(&connection_ip_addr))
131        .unwrap_or_default();
132
133    let client_ip_addr = if trust_x_forward_for {
134        if let Some(x_forward_for) = request.headers().get(X_FORWARDED_FOR_HEADER) {
135            // X forward for may be comma separated.
136            let first = x_forward_for
137                .to_str()
138                .map(|s|
139                    // Split on an optional comma, return the first result.
140                    s.split(',').next().unwrap_or(s))
141                .map_err(|_| {
142                    (
143                        StatusCode::BAD_REQUEST,
144                        "X-Forwarded-For contains invalid data",
145                    )
146                })?;
147
148            first.parse::<IpAddr>().map_err(|_| {
149                (
150                    StatusCode::BAD_REQUEST,
151                    "X-Forwarded-For contains invalid ip addr",
152                )
153            })?
154        } else {
155            client_ip_addr
156        }
157    } else {
158        // This can either be the client_addr == connection_addr if there are
159        // no ip address trust sources, or this is the value as reported by
160        // proxy protocol header. If the proxy protocol header is used, then
161        // trust_x_forward_for can never have been true so we catch here.
162        client_ip_addr
163    };
164
165    Ok(ClientConnInfo {
166        connection_addr,
167        client_ip_addr,
168        client_cert,
169    })
170}