kanidmd_core/https/middleware/
mod.rs1use crate::https::extractors::ClientConnInfo;
2use crate::https::ServerState;
3use axum::{
4 body::Body,
5 extract::{connect_info::ConnectInfo, State},
6 http::{header::HeaderName, StatusCode},
7 http::{HeaderValue, Request},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 RequestExt,
11};
12use kanidm_proto::constants::{KOPID, KVERSION, X_FORWARDED_FOR};
13use std::net::IpAddr;
14use uuid::Uuid;
15
16#[allow(clippy::declare_interior_mutable_const)]
17const X_FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static(X_FORWARDED_FOR);
18
19pub(crate) mod caching;
20pub(crate) mod compression;
21pub(crate) mod hsts_header;
22pub(crate) mod security_headers;
23
24const KANIDM_VERSION: &str = env!("CARGO_PKG_VERSION");
26
27pub async fn version_middleware(request: Request<Body>, next: Next) -> Response {
29 let mut response = next.run(request).await;
30 response
31 .headers_mut()
32 .insert(KVERSION, HeaderValue::from_static(KANIDM_VERSION));
33 response
34}
35
36#[cfg(any(test, debug_assertions))]
37#[instrument(level = "trace", name = "are_we_json_yet", skip_all)]
39pub async fn are_we_json_yet(request: Request<Body>, next: Next) -> Response {
40 let uri = request.uri().path().to_string();
41
42 let response = next.run(request).await;
43
44 if uri.starts_with("/v1") && response.status().is_success() {
45 let headers = response.headers();
46 assert!(headers.contains_key(axum::http::header::CONTENT_TYPE));
47 assert!(
48 headers.get(axum::http::header::CONTENT_TYPE)
49 == Some(&HeaderValue::from_static(
50 kanidm_proto::constants::APPLICATION_JSON
51 ))
52 );
53 }
54
55 response
56}
57
58#[derive(Clone, Debug)]
59pub struct KOpId {
61 pub eventid: Uuid,
63}
64
65#[instrument(level = "trace", name = "kopid_middleware", skip_all)]
67pub async fn kopid_middleware(mut request: Request<Body>, next: Next) -> Response {
68 let eventid = sketching::tracing_forest::id();
70
71 request.extensions_mut().insert(KOpId { eventid });
73 let mut response = next.run(request).await;
74
75 let _ = HeaderValue::from_str(&eventid.as_hyphenated().to_string())
78 .map(|hv| response.headers_mut().insert(KOPID, hv))
79 .map_err(|err| {
80 warn!(?err, "An invalid operation id was encountered");
81 });
82
83 response
84}
85
86pub async fn ip_address_middleware(
89 State(state): State<ServerState>,
90 mut request: Request<Body>,
91 next: Next,
92) -> Response {
93 match ip_address_middleware_inner(&state, &mut request).await {
94 Ok(trusted_client_ip) => {
95 info!(connection_addr = %trusted_client_ip.connection_addr, client_ip_addr = %trusted_client_ip.client_ip_addr);
97 request.extensions_mut().insert(trusted_client_ip);
98 next.run(request).await
99 }
100 Err(err_status_and_reason) => err_status_and_reason.into_response(),
101 }
102}
103
104async fn ip_address_middleware_inner(
105 state: &ServerState,
106 request: &mut Request<Body>,
107) -> Result<ClientConnInfo, (StatusCode, &'static str)> {
108 let ConnectInfo(ClientConnInfo {
110 connection_addr,
111 client_ip_addr,
112 client_cert,
113 }) = request
114 .extract_parts::<ConnectInfo<ClientConnInfo>>()
115 .await
116 .map_err(|_| {
117 error!("Connect info contains invalid data");
118 (
119 StatusCode::INTERNAL_SERVER_ERROR,
120 "connect info contains invalid data",
121 )
122 })?;
123
124 let connection_ip_addr = connection_addr.ip();
125
126 let trust_x_forward_for = state
127 .trust_x_forward_for_ips
128 .as_ref()
129 .map(|range| range.contains(&connection_ip_addr))
130 .unwrap_or_default();
131
132 let client_ip_addr = if trust_x_forward_for {
133 if let Some(x_forward_for) = request.headers().get(X_FORWARDED_FOR_HEADER) {
134 let first = x_forward_for
136 .to_str()
137 .map(|s|
138 s.split(',').next().unwrap_or(s))
140 .map_err(|_| {
141 (
142 StatusCode::BAD_REQUEST,
143 "X-Forwarded-For contains invalid data",
144 )
145 })?;
146
147 first.parse::<IpAddr>().map_err(|_| {
148 (
149 StatusCode::BAD_REQUEST,
150 "X-Forwarded-For contains invalid ip addr",
151 )
152 })?
153 } else {
154 client_ip_addr
155 }
156 } else {
157 client_ip_addr
162 };
163
164 Ok(ClientConnInfo {
165 connection_addr,
166 client_ip_addr,
167 client_cert,
168 })
169}