kanidmd_core/https/mod.rs

1mod apidocs;
2pub(crate) mod cache_buster;
3pub(crate) mod errors;
4mod extractors;
5mod generic;
6mod javascript;
7mod manifest;
8pub(crate) mod middleware;
9mod oauth2;
10pub(crate) mod trace;
11mod v1;
12mod v1_domain;
13mod v1_oauth2;
14mod v1_scim;
15mod views;
16
17use self::extractors::ClientConnInfo;
18use self::javascript::*;
19use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
20use crate::config::{AddressSet, Configuration, ServerRole};
21use crate::CoreAction;
22use axum::{
23    body::Body,
24    extract::connect_info::IntoMakeServiceWithConnectInfo,
25    http::{HeaderMap, HeaderValue, Request, StatusCode},
26    middleware::{from_fn, from_fn_with_state},
27    response::{IntoResponse, Redirect, Response},
28    routing::*,
29    Router,
30};
31use axum_extra::extract::cookie::CookieJar;
32use cidr::IpCidr;
33use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
34use futures::pin_mut;
35use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo};
38use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
39use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
40use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
41use serde::de::DeserializeOwned;
42use sketching::*;
43use std::fmt::Write;
44use std::io::ErrorKind;
45use std::path::PathBuf;
46use std::sync::Arc;
47use std::{
48    net::{IpAddr, SocketAddr},
49    str::FromStr,
50};
51use tokio::{
52    io::{AsyncRead, AsyncWrite},
53    net::{TcpListener, TcpStream},
54    sync::broadcast,
55    sync::mpsc,
56    task,
57};
58use tokio_rustls::TlsAcceptor;
59use tower::Service;
60use tower_http::{services::ServeDir, trace::TraceLayer};
61use url::Url;
62use uuid::Uuid;
63
64#[derive(Clone)]
65pub struct ServerState {
66    pub(crate) status_ref: &'static StatusActor,
67    pub(crate) qe_w_ref: &'static QueryServerWriteV1,
68    pub(crate) qe_r_ref: &'static QueryServerReadV1,
69    // Store the token management parts.
70    pub(crate) jws_signer: JwsHs256Signer,
71    pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>,
72    pub(crate) csp_header: HeaderValue,
73    pub(crate) csp_header_no_form_action: HeaderValue,
74    pub(crate) origin: Url,
75    pub(crate) domain: String,
76    // This is set to true by default, and is only false on integration tests.
77    pub(crate) secure_cookies: bool,
78}
79
80impl ServerState {
81    /// Deserialize some input string validating that it was signed by our instance's
82    /// HMAC signer. This is used for short lived server-only sessions and context
83    /// data. This has applications in both accessing cookie content and header content.
84    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
85        match JwsCompact::from_str(input) {
86            Ok(val) => match self.jws_signer.verify(&val) {
87                Ok(val) => val.from_json::<T>().ok(),
88                Err(err) => {
89                    error!(?err, "Failed to deserialise JWT from request");
90                    if matches!(err, JwtError::InvalidSignature) {
91                        // The server has an ephemeral in memory HMAC signer. This is important as
92                        // auth (login) sessions on one node shouldn't validate on another. Sessions
93                        // that are shared between nodes use the internal ECDSA signer.
94                        //
95                        // But because of this if the server restarts it rolls the key. Additionally
96                        // it can occur if the load balancer isn't sticking sessions to the correct
97                        // node. That can cause this error. So we want to specifically call it out
98                        // to admins so they can investigate that the fault is occurring *outside*
99                        // of kanidm.
100                        warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
101                    }
102                    None
103                }
104            },
105            Err(_) => None,
106        }
107    }
108
109    #[instrument(level = "trace", skip_all)]
110    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
111        // We see if there is a signed header copy first.
112        headers
113            .get(KSESSIONID)
114            .and_then(|hv| {
115                trace!("trying header");
116                // Get the first header value.
117                hv.to_str().ok()
118            })
119            .or_else(|| {
120                trace!("trying cookie");
121                jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
122            })
123            .and_then(|s| {
124                trace!(id_jws = %s);
125                self.deserialise_from_str::<Uuid>(s)
126            })
127    }
128}
129
130pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
131    let mut all_pages: Vec<JavaScriptFile> = Vec::new();
132
133    if !matches!(role, ServerRole::WriteReplicaNoUI) {
134        // let's set up the list of js module hashes
135        let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();
136
137        let filelist = [
138            "external/bootstrap.bundle.min.js",
139            "external/htmx.min.1.9.12.js",
140            "external/confetti.js",
141            "external/base64.js",
142            "modules/cred_update.mjs",
143            "pkhtml.js",
144            "style.js",
145        ];
146
147        for filepath in filelist {
148            match generate_integrity_hash(format!("{pkg_path}/{filepath}",)) {
149                Ok(hash) => {
150                    debug!("Integrity hash for {}: {}", filepath, hash);
151                    let js = JavaScriptFile { hash };
152                    all_pages.push(js)
153                }
154                Err(err) => {
155                    admin_error!(
156                        ?err,
157                        "Failed to generate integrity hash for {} - cancelling startup!",
158                        filepath
159                    );
160                    return Err(());
161                }
162            }
163        }
164    }
165    Ok(all_pages)
166}
167
168async fn handler_404() -> Response {
169    (StatusCode::NOT_FOUND, "Route not found").into_response()
170}
171
172pub async fn create_https_server(
173    config: Configuration,
174    jws_signer: JwsHs256Signer,
175    status_ref: &'static StatusActor,
176    qe_w_ref: &'static QueryServerWriteV1,
177    qe_r_ref: &'static QueryServerReadV1,
178    server_message_tx: broadcast::Sender<CoreAction>,
179    maybe_tls_acceptor: Option<TlsAcceptor>,
180    tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
181) -> Result<task::JoinHandle<()>, ()> {
182    let rx = server_message_tx.subscribe();
183
184    let all_js_files = get_js_files(config.role)?;
185    // set up the CSP headers
186    // script-src 'self'
187    //      'sha384-Zao7ExRXVZOJobzS/uMp0P1jtJz3TTqJU4nYXkdmsjpiVD+/wcwCyX7FGqRIqvIz'
188    //      'sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM';
189
190    let js_directives = all_js_files
191        .into_iter()
192        .map(|f| f.hash)
193        .collect::<Vec<String>>();
194
195    let js_checksums: String = js_directives
196        .iter()
197        .fold(String::new(), |mut output, value| {
198            let _ = write!(output, " 'sha384-{value}'");
199            output
200        });
201
202    let csp_header = format!(
203        concat!(
204            "default-src 'self'; ",
205            "base-uri 'self' https:; ",
206            "form-action 'self'; ",
207            "frame-ancestors 'none'; ",
208            "img-src 'self' data:; ",
209            "worker-src 'none'; ",
210            "script-src 'self' 'unsafe-eval'{};",
211        ),
212        js_checksums
213    );
214
215    let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
216        error!(?err, "Unable to generate content security policy");
217    })?;
218
219    // Omit form action - form action is interpreted by chrome to also control valid
220    // redirect targets on submit. This breaks oauth2 in many cases.
221    //
222    // Normally this would be considered BAD to remove a CSP control to make Oauth2 work
223    // but we need to consider the primary attack form-action protects from - open redirectors
224    // in the form submission. Since the paths that use this header do NOT have open
225    // redirectors, we are safe to remove the form-action directive.
226    let csp_header_no_form_action = format!(
227        concat!(
228            "default-src 'self'; ",
229            "base-uri 'self' https:; ",
230            "frame-ancestors 'none'; ",
231            "img-src 'self' data:; ",
232            "worker-src 'none'; ",
233            "script-src 'self' 'unsafe-eval'{};",
234        ),
235        js_checksums
236    );
237
238    let csp_header_no_form_action =
239        HeaderValue::from_str(&csp_header_no_form_action).map_err(|err| {
240            error!(
241                ?err,
242                "Unable to generate content security policy with no form action"
243            );
244        })?;
245
246    let trust_x_forward_for_ips = config
247        .http_client_address_info
248        .trusted_x_forward_for()
249        .map(Arc::new);
250
251    let trusted_proxy_v2_ips = config
252        .http_client_address_info
253        .trusted_proxy_v2()
254        .map(Arc::new);
255
256    let state = ServerState {
257        status_ref,
258        qe_w_ref,
259        qe_r_ref,
260        jws_signer,
261        trust_x_forward_for_ips,
262        csp_header,
263        csp_header_no_form_action,
264        origin: config.origin,
265        domain: config.domain.clone(),
266        secure_cookies: config.integration_test_config.is_none(),
267    };
268
269    let static_routes = match config.role {
270        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
271            Router::new()
272                .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
273                .route("/ui/images/domain", get(v1_domain::image_get))
274                .route("/manifest.webmanifest", get(manifest::manifest)) // skip_route_check
275                // Layers only apply to routes that are *already* added, not the ones
276                // added after.
277                .layer(middleware::compression::new())
278                .layer(from_fn(middleware::caching::cache_me_short))
279                .route("/", get(|| async { Redirect::to("/ui") }))
280                .nest("/ui", views::view_router(state.clone()))
281            // Can't compress on anything that changes
282        }
283        ServerRole::WriteReplicaNoUI => Router::new(),
284    };
285    let app = Router::new()
286        .merge(oauth2::route_setup(state.clone()))
287        .merge(v1_scim::route_setup())
288        .merge(v1::route_setup(state.clone()))
289        .route("/robots.txt", get(generic::robots_txt))
290        .route(
291            views::constants::Urls::WellKnownChangePassword.as_ref(),
292            get(generic::redirect_to_update_credentials),
293        );
294
295    let app = match config.role {
296        ServerRole::WriteReplicaNoUI => app,
297        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
298            let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
299            if !pkg_path.exists() {
300                eprintln!(
301                    "Couldn't find htmx UI package path: ({}), quitting.",
302                    env!("KANIDM_SERVER_UI_PKG_PATH")
303                );
304                std::process::exit(1);
305            }
306            let pkg_router = Router::new()
307                .nest_service("/pkg", ServeDir::new(pkg_path))
308                // TODO: Add in the br precompress
309                .layer(from_fn(middleware::caching::cache_me_short));
310
311            app.merge(pkg_router)
312        }
313    };
314
315    // this sets up the default span which logs the URL etc.
316    let trace_layer = TraceLayer::new_for_http()
317        .make_span_with(trace::DefaultMakeSpanKanidmd::new())
318        // setting these to trace because all they do is print "started processing request", and we are already doing that enough!
319        .on_response(trace::DefaultOnResponseKanidmd::new());
320
321    let app = app
322        .merge(static_routes)
323        .layer(from_fn_with_state(
324            state.clone(),
325            middleware::security_headers::security_headers_layer,
326        ))
327        .layer(from_fn(middleware::version_middleware))
328        .layer(from_fn(
329            middleware::hsts_header::strict_transport_security_layer,
330        ));
331
332    // layer which checks the responses have a content-type of JSON when we're in debug mode
333    #[cfg(any(test, debug_assertions))]
334    let app = app.layer(from_fn(middleware::are_we_json_yet));
335
336    let app = app
337        .route("/status", get(generic::status))
338        // 404 handler
339        .fallback(handler_404)
340        // This must be the LAST middleware.
341        // This is because the last middleware here is the first to be entered and the last
342        // to be exited, and this middleware sets up ids' and other bits for for logging
343        // coherence to be maintained.
344        .layer(from_fn_with_state(
345            state.clone(),
346            middleware::ip_address_middleware,
347        ))
348        .layer(from_fn(middleware::kopid_middleware))
349        .merge(apidocs::router())
350        // this MUST be the last layer before with_state else the span never starts and everything breaks.
351        .layer(trace_layer)
352        .with_state(state)
353        // the connect_info bit here lets us pick up the remote address of the client
354        .into_make_service_with_connect_info::<ClientConnInfo>();
355
356    let addr = SocketAddr::from_str(&config.address).map_err(|err| {
357        error!(
358            "Failed to parse address ({:?}) from config: {:?}",
359            config.address, err
360        );
361    })?;
362
363    info!("Starting the web server...");
364
365    let listener = match TcpListener::bind(addr).await {
366        Ok(l) => l,
367        Err(err) => {
368            error!(?err, "Failed to bind tcp listener");
369            return Err(());
370        }
371    };
372
373    match maybe_tls_acceptor {
374        Some(tls_acceptor) => Ok(task::spawn(server_tls_loop(
375            tls_acceptor,
376            listener,
377            app,
378            rx,
379            server_message_tx,
380            tls_acceptor_reload_rx,
381            trusted_proxy_v2_ips,
382        ))),
383        None => Ok(task::spawn(server_plaintext_loop(
384            listener,
385            app,
386            rx,
387            trusted_proxy_v2_ips,
388        ))),
389    }
390}
391
392async fn server_tls_loop(
393    mut tls_acceptor: TlsAcceptor,
394    listener: TcpListener,
395    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
396    mut rx: broadcast::Receiver<CoreAction>,
397    server_message_tx: broadcast::Sender<CoreAction>,
398    mut tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
399    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
400) {
401    pin_mut!(listener);
402
403    loop {
404        tokio::select! {
405            Ok(action) = rx.recv() => {
406                match action {
407                    CoreAction::Shutdown => break,
408                }
409            }
410            accept = listener.accept() => {
411                match accept {
412                    Ok((stream, addr)) => {
413                        let tls_acceptor = tls_acceptor.clone();
414                        let app = app.clone();
415                        task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone()));
416                    }
417                    Err(err) => {
418                        error!("Web server exited with {:?}", err);
419                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
420                            error!("Web server failed to send shutdown message! {:?}", err)
421                        };
422                        break;
423                    }
424                }
425            }
426            Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
427                std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
428                info!("Reloaded http tls acceptor");
429            }
430        }
431    }
432
433    info!("Stopped {}", super::TaskName::HttpsServer);
434}
435
436async fn server_plaintext_loop(
437    listener: TcpListener,
438    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
439    mut rx: broadcast::Receiver<CoreAction>,
440    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
441) {
442    pin_mut!(listener);
443
444    loop {
445        tokio::select! {
446            Ok(action) = rx.recv() => {
447                match action {
448                    CoreAction::Shutdown => break,
449                }
450            }
451            accept = listener.accept() => {
452                match accept {
453                    Ok((stream, addr)) => {
454                        let app = app.clone();
455                        task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone()));
456                    }
457                    Err(err) => {
458                        error!("Web server exited with {:?}", err);
459                        break;
460                    }
461                }
462            }
463        }
464    }
465
466    info!("Stopped {}", super::TaskName::HttpsServer);
467}
468
469/// This handles an individual connection.
470pub(crate) async fn handle_conn(
471    stream: TcpStream,
472    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
473    connection_addr: SocketAddr,
474    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
475) -> Result<(), std::io::Error> {
476    let (stream, client_ip_addr) =
477        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
478
479    let client_conn_info = ClientConnInfo {
480        connection_addr,
481        client_ip_addr,
482        client_cert: None,
483    };
484
485    // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
486    // `TokioIo` converts between them.
487    let stream = TokioIo::new(stream);
488
489    process_client_hyper(stream, app, client_conn_info).await
490}
491
492/// This handles an individual connection.
493pub(crate) async fn handle_tls_conn(
494    acceptor: TlsAcceptor,
495    stream: TcpStream,
496    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
497    connection_addr: SocketAddr,
498    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
499) -> Result<(), std::io::Error> {
500    let (stream, client_ip_addr) =
501        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
502
503    let tls_stream = acceptor.accept(stream).await.map_err(|err| {
504        error!(?err, "Failed to create TLS stream");
505        std::io::Error::from(ErrorKind::ConnectionAborted)
506    })?;
507
508    let maybe_peer_cert = tls_stream
509        .get_ref()
510        .1
511        .peer_certificates()
512        // The first certificate relates to the peer.
513        .and_then(|peer_certs| peer_certs.first());
514
515    // Process the client cert (if any)
516    let client_cert = if let Some(peer_cert) = maybe_peer_cert {
517        // We don't need to check the CRL here - it's already completed as part of the
518        // TLS connection establishment process.
519
520        // Extract the cert from rustls DER to x509-cert which is a better
521        // parser to handle the various extensions.
522        let certificate = Certificate::from_der(peer_cert).map_err(|ossl_err| {
523            error!(?ossl_err, "unable to process DER certificate to x509");
524            std::io::Error::from(ErrorKind::ConnectionAborted)
525        })?;
526
527        let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
528            error!("subject public key bitstring is not octet aligned");
529            std::io::Error::from(ErrorKind::ConnectionAborted)
530        })?;
531
532        Some(ClientCertInfo {
533            public_key_s256,
534            certificate,
535        })
536    } else {
537        None
538    };
539
540    let client_conn_info = ClientConnInfo {
541        connection_addr,
542        client_ip_addr,
543        client_cert,
544    };
545
546    // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
547    // `TokioIo` converts between them.
548    let stream = TokioIo::new(tls_stream);
549
550    process_client_hyper(stream, app, client_conn_info).await
551}
552
553async fn process_client_addr(
554    stream: TcpStream,
555    connection_addr: SocketAddr,
556    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
557) -> Result<(TcpStream, IpAddr), std::io::Error> {
558    let enable_proxy_v2_hdr = trusted_proxy_v2_ips
559        .map(|trusted| {
560            trusted
561                .iter()
562                .any(|ip_cidr| ip_cidr.contains(&connection_addr.ip().to_canonical()))
563        })
564        .unwrap_or_default();
565
566    let (stream, client_addr) = if enable_proxy_v2_hdr {
567        match ProxyHdrV2::parse_from_read(stream).await {
568            Ok((stream, hdr)) => {
569                let remote_socket_addr = match hdr.to_remote_addr() {
570                    RemoteAddress::Local => {
571                        debug!("PROXY protocol liveness check - will not contain client data");
572                        // This is a check from the proxy, so just use the connection address.
573                        connection_addr
574                    }
575                    RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
576                    RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
577                    remote_addr => {
578                        error!(?remote_addr, "remote address in proxy header is invalid");
579                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
580                    }
581                };
582
583                (stream, remote_socket_addr)
584            }
585            Err(err) => {
586                error!(?connection_addr, ?err, "Unable to process proxy v2 header");
587                return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
588            }
589        }
590    } else {
591        (stream, connection_addr)
592    };
593
594    Ok((stream, client_addr.ip()))
595}
596
597async fn process_client_hyper<T>(
598    stream: TokioIo<T>,
599    mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
600    client_conn_info: ClientConnInfo,
601) -> Result<(), std::io::Error>
602where
603    T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
604{
605    debug!(?client_conn_info);
606
607    let svc = tower::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
608        &mut app,
609        client_conn_info,
610    );
611
612    let svc = svc.await.map_err(|e| {
613        error!("Failed to build HTTP response: {:?}", e);
614        std::io::Error::from(ErrorKind::Other)
615    })?;
616
617    // Hyper also has its own `Service` trait and doesn't use tower. We can use
618    // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
619    // `tower::Service::call`.
620    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
621        // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
622        // tower's `Service` requires `&mut self`.
623        //
624        // We don't need to call `poll_ready` since `Router` is always ready.
625        svc.clone().call(request)
626    });
627
628    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
629        .serve_connection_with_upgrades(stream, hyper_service)
630        .await
631        .map_err(|e| {
632            debug!("Failed to complete connection: {:?}", e);
633            std::io::Error::from(ErrorKind::ConnectionAborted)
634        })
635}