kanidmd_core/https/middleware/mod.rs

1use crate::https::extractors::ClientConnInfo;
2use crate::https::ServerState;
3use axum::{
4    body::Body,
5    extract::{connect_info::ConnectInfo, State},
6    http::{header::HeaderName, StatusCode},
7    http::{HeaderValue, Request},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    RequestExt,
11};
12use kanidm_proto::constants::{KOPID, KVERSION, X_FORWARDED_FOR};
13use std::net::IpAddr;
14use uuid::Uuid;
15
16#[allow(clippy::declare_interior_mutable_const)]
17const X_FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static(X_FORWARDED_FOR);
18
19pub(crate) mod caching;
20pub(crate) mod compression;
21pub(crate) mod hsts_header;
22pub(crate) mod security_headers;
23
24// the version middleware injects
25const KANIDM_VERSION: &str = env!("CARGO_PKG_VERSION");
26
27/// Injects a header into the response with "X-KANIDM-VERSION" matching the version of the package.
28pub async fn version_middleware(request: Request<Body>, next: Next) -> Response {
29    let mut response = next.run(request).await;
30    response
31        .headers_mut()
32        .insert(KVERSION, HeaderValue::from_static(KANIDM_VERSION));
33    response
34}
35
36#[cfg(any(test, debug_assertions))]
37/// This is a debug middleware to ensure that /v1/ endpoints only return JSON
38#[instrument(level = "trace", name = "are_we_json_yet", skip_all)]
39pub async fn are_we_json_yet(request: Request<Body>, next: Next) -> Response {
40    let uri = request.uri().path().to_string();
41
42    let response = next.run(request).await;
43
44    if uri.starts_with("/v1") && response.status().is_success() {
45        let headers = response.headers();
46        assert!(headers.contains_key(axum::http::header::CONTENT_TYPE));
47        assert!(
48            headers.get(axum::http::header::CONTENT_TYPE)
49                == Some(&HeaderValue::from_static(
50                    kanidm_proto::constants::APPLICATION_JSON
51                ))
52        );
53    }
54
55    response
56}
57
58#[derive(Clone, Debug)]
59/// For holding onto the event ID and other handy request-based things
60pub struct KOpId {
61    /// The event correlation ID
62    pub eventid: Uuid,
63}
64
65/// This runs at the start of the request, adding an extension with `KOpId` which has useful things inside it.
66#[instrument(level = "trace", name = "kopid_middleware", skip_all)]
67pub async fn kopid_middleware(mut request: Request<Body>, next: Next) -> Response {
68    // generate the event ID
69    let eventid = sketching::tracing_forest::id();
70
71    // insert the extension so we can pull it out later
72    request.extensions_mut().insert(KOpId { eventid });
73    let mut response = next.run(request).await;
74
75    // This conversion *should never* fail. If it does, rather than panic, we warn and
76    // just don't put the id in the response.
77    let _ = HeaderValue::from_str(&eventid.as_hyphenated().to_string())
78        .map(|hv| response.headers_mut().insert(KOPID, hv))
79        .map_err(|err| {
80            warn!(?err, "An invalid operation id was encountered");
81        });
82
83    response
84}
85
86// This middleware extracts the ip_address and client information, and stores it
87// in the request extensions for future layers to use it.
88pub async fn ip_address_middleware(
89    State(state): State<ServerState>,
90    mut request: Request<Body>,
91    next: Next,
92) -> Response {
93    match ip_address_middleware_inner(&state, &mut request).await {
94        Ok(trusted_client_ip) => {
95            // By this point, proxy-v2 AND x-forward-for have resolved, so we can finally display this information.
96            info!(connection_addr = %trusted_client_ip.connection_addr, client_ip_addr = %trusted_client_ip.client_ip_addr);
97            request.extensions_mut().insert(trusted_client_ip);
98            next.run(request).await
99        }
100        Err(err_status_and_reason) => err_status_and_reason.into_response(),
101    }
102}
103
104async fn ip_address_middleware_inner(
105    state: &ServerState,
106    request: &mut Request<Body>,
107) -> Result<ClientConnInfo, (StatusCode, &'static str)> {
108    // Extract the IP and insert it to the request.
109    let ConnectInfo(ClientConnInfo {
110        connection_addr,
111        client_ip_addr,
112        client_cert,
113    }) = request
114        .extract_parts::<ConnectInfo<ClientConnInfo>>()
115        .await
116        .map_err(|_| {
117            error!("Connect info contains invalid data");
118            (
119                StatusCode::INTERNAL_SERVER_ERROR,
120                "connect info contains invalid data",
121            )
122        })?;
123
124    let connection_ip_addr = connection_addr.ip();
125
126    let trust_x_forward_for = state
127        .trust_x_forward_for_ips
128        .as_ref()
129        .map(|range| range.contains(&connection_ip_addr))
130        .unwrap_or_default();
131
132    let client_ip_addr = if trust_x_forward_for {
133        if let Some(x_forward_for) = request.headers().get(X_FORWARDED_FOR_HEADER) {
134            // X forward for may be comma separated.
135            let first = x_forward_for
136                .to_str()
137                .map(|s|
138                    // Split on an optional comma, return the first result.
139                    s.split(',').next().unwrap_or(s))
140                .map_err(|_| {
141                    (
142                        StatusCode::BAD_REQUEST,
143                        "X-Forwarded-For contains invalid data",
144                    )
145                })?;
146
147            first.parse::<IpAddr>().map_err(|_| {
148                (
149                    StatusCode::BAD_REQUEST,
150                    "X-Forwarded-For contains invalid ip addr",
151                )
152            })?
153        } else {
154            client_ip_addr
155        }
156    } else {
157        // This can either be the client_addr == connection_addr if there are
158        // no ip address trust sources, or this is the value as reported by
159        // proxy protocol header. If the proxy protocol header is used, then
160        // trust_x_forward_for can never have been true so we catch here.
161        client_ip_addr
162    };
163
164    Ok(ClientConnInfo {
165        connection_addr,
166        client_ip_addr,
167        client_cert,
168    })
169}