kanidmd_core/https/extractors/mod.rs

1use axum::{
2    async_trait,
3    extract::connect_info::{ConnectInfo, Connected},
4    extract::FromRequestParts,
5    http::{
6        header::HeaderName, header::AUTHORIZATION as AUTHORISATION, request::Parts, StatusCode,
7    },
8    RequestPartsExt,
9};
10
11use axum_extra::extract::cookie::CookieJar;
12
13use kanidm_proto::constants::X_FORWARDED_FOR;
14use kanidm_proto::internal::COOKIE_BEARER_TOKEN;
15use kanidmd_lib::prelude::{ClientAuthInfo, ClientCertInfo, Source};
16// Re-export
17pub use kanidmd_lib::idm::server::DomainInfoRead;
18
19use compact_jwt::JwsCompact;
20use std::str::FromStr;
21
22use std::net::{IpAddr, SocketAddr};
23
24use crate::https::ServerState;
25
/// Header name of the `X-Forwarded-For` header, which the extractors in this module
/// consult only when the directly connected peer is a trusted proxy.
// Clippy flags `HeaderName` consts as possibly interior-mutable. In this file the const
// is only passed by value to header lookups, so a fresh copy per use is harmless.
#[allow(clippy::declare_interior_mutable_const)]
const X_FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static(X_FORWARDED_FOR);
28
29pub struct TrustedClientIp(pub IpAddr);
30
31#[async_trait]
32impl FromRequestParts<ServerState> for TrustedClientIp {
33    type Rejection = (StatusCode, &'static str);
34
35    // Need to skip all to prevent leaking tokens to logs.
36    #[instrument(level = "debug", skip_all)]
37    async fn from_request_parts(
38        parts: &mut Parts,
39        state: &ServerState,
40    ) -> Result<Self, Self::Rejection> {
41        let ConnectInfo(ClientConnInfo {
42            connection_addr,
43            client_addr,
44            client_cert: _,
45        }) = parts
46            .extract::<ConnectInfo<ClientConnInfo>>()
47            .await
48            .map_err(|_| {
49                error!("Connect info contains invalid data");
50                (
51                    StatusCode::BAD_REQUEST,
52                    "connect info contains invalid data",
53                )
54            })?;
55
56        let trust_x_forward_for = state
57            .trust_x_forward_for_ips
58            .as_ref()
59            .map(|range| range.contains(&connection_addr.ip()))
60            .unwrap_or_default();
61
62        let ip_addr = if trust_x_forward_for {
63            if let Some(x_forward_for) = parts.headers.get(X_FORWARDED_FOR_HEADER) {
64                // X forward for may be comma separated.
65                let first = x_forward_for
66                    .to_str()
67                    .map(|s|
68                        // Split on an optional comma, return the first result.
69                        s.split(',').next().unwrap_or(s))
70                    .map_err(|_| {
71                        (
72                            StatusCode::BAD_REQUEST,
73                            "X-Forwarded-For contains invalid data",
74                        )
75                    })?;
76
77                first.parse::<IpAddr>().map_err(|_| {
78                    (
79                        StatusCode::BAD_REQUEST,
80                        "X-Forwarded-For contains invalid ip addr",
81                    )
82                })?
83            } else {
84                client_addr.ip()
85            }
86        } else {
87            // This can either be the client_addr == connection_addr if there are
88            // no ip address trust sources, or this is the value as reported by
89            // proxy protocol header. If the proxy protocol header is used, then
90            // trust_x_forward_for can never have been true so we catch here.
91            client_addr.ip()
92        };
93
94        Ok(TrustedClientIp(ip_addr))
95    }
96}
97
98pub struct VerifiedClientInformation(pub ClientAuthInfo);
99
100#[async_trait]
101impl FromRequestParts<ServerState> for VerifiedClientInformation {
102    type Rejection = (StatusCode, &'static str);
103
104    // Need to skip all to prevent leaking tokens to logs.
105    #[instrument(level = "debug", skip_all)]
106    async fn from_request_parts(
107        parts: &mut Parts,
108        state: &ServerState,
109    ) -> Result<Self, Self::Rejection> {
110        let ConnectInfo(ClientConnInfo {
111            connection_addr,
112            client_addr,
113            client_cert,
114        }) = parts
115            .extract::<ConnectInfo<ClientConnInfo>>()
116            .await
117            .map_err(|_| {
118                error!("Connect info contains invalid data");
119                (
120                    StatusCode::BAD_REQUEST,
121                    "connect info contains invalid data",
122                )
123            })?;
124
125        let trust_x_forward_for = state
126            .trust_x_forward_for_ips
127            .as_ref()
128            .map(|range| range.contains(&connection_addr.ip()))
129            .unwrap_or_default();
130
131        let ip_addr = if trust_x_forward_for {
132            if let Some(x_forward_for) = parts.headers.get(X_FORWARDED_FOR_HEADER) {
133                // X forward for may be comma separated.
134                let first = x_forward_for
135                    .to_str()
136                    .map(|s|
137                        // Split on an optional comma, return the first result.
138                        s.split(',').next().unwrap_or(s))
139                    .map_err(|_| {
140                        (
141                            StatusCode::BAD_REQUEST,
142                            "X-Forwarded-For contains invalid data",
143                        )
144                    })?;
145
146                first.parse::<IpAddr>().map_err(|_| {
147                    (
148                        StatusCode::BAD_REQUEST,
149                        "X-Forwarded-For contains invalid ip addr",
150                    )
151                })?
152            } else {
153                client_addr.ip()
154            }
155        } else {
156            client_addr.ip()
157        };
158
159        let (basic_authz, bearer_token) = if let Some(header) = parts.headers.get(AUTHORISATION) {
160            if let Some((authz_type, authz_data)) = header
161                .to_str()
162                .map_err(|err| {
163                    warn!(?err, "Invalid authz header, ignoring");
164                })
165                .ok()
166                .and_then(|s| s.split_once(' '))
167            {
168                let authz_type = authz_type.to_lowercase();
169
170                if authz_type == "basic" {
171                    (Some(authz_data.to_string()), None)
172                } else if authz_type == "bearer" {
173                    if let Ok(jwsc) = JwsCompact::from_str(authz_data) {
174                        (None, Some(jwsc))
175                    } else {
176                        warn!("bearer jws invalid");
177                        (None, None)
178                    }
179                } else {
180                    warn!("authorisation header invalid, ignoring");
181                    (None, None)
182                }
183            } else {
184                (None, None)
185            }
186        } else {
187            // Only if there are no credentials in bearer, do we examine cookies.
188            let jar = CookieJar::from_headers(&parts.headers);
189
190            let value: Option<&str> = jar.get(COOKIE_BEARER_TOKEN).map(|c| c.value());
191
192            let maybe_bearer = value.and_then(|authz_data| JwsCompact::from_str(authz_data).ok());
193
194            (None, maybe_bearer)
195        };
196
197        Ok(VerifiedClientInformation(ClientAuthInfo {
198            source: Source::Https(ip_addr),
199            bearer_token,
200            basic_authz,
201            client_cert,
202        }))
203    }
204}
205
206pub struct DomainInfo(pub DomainInfoRead);
207
208#[async_trait]
209impl FromRequestParts<ServerState> for DomainInfo {
210    type Rejection = (StatusCode, &'static str);
211
212    // Need to skip all to prevent leaking tokens to logs.
213    #[instrument(level = "debug", skip_all)]
214    async fn from_request_parts(
215        _parts: &mut Parts,
216        state: &ServerState,
217    ) -> Result<Self, Self::Rejection> {
218        Ok(DomainInfo(state.qe_r_ref.domain_info_read()))
219    }
220}
221
222#[derive(Debug, Clone)]
223pub struct ClientConnInfo {
224    /// This is the address that is *connected* to Kanidm right now
225    /// for this operation.
226    #[allow(dead_code)]
227    pub connection_addr: SocketAddr,
228    /// This is the client address as reported by a remote IP source
229    /// such as x-forward-for or the PROXY protocol header
230    pub client_addr: SocketAddr,
231    // Only set if the certificate is VALID
232    pub client_cert: Option<ClientCertInfo>,
233}
234
235// This is the normal way that our extractors get the ip info
236impl Connected<ClientConnInfo> for ClientConnInfo {
237    fn connect_info(target: ClientConnInfo) -> Self {
238        target
239    }
240}
241
242// This is only used for plaintext http - in other words, integration tests only.
243impl Connected<SocketAddr> for ClientConnInfo {
244    fn connect_info(connection_addr: SocketAddr) -> Self {
245        ClientConnInfo {
246            client_addr: connection_addr,
247            connection_addr,
248            client_cert: None,
249        }
250    }
251}