kanidmd_core/https/mod.rs

1mod apidocs;
2pub(crate) mod cache_buster;
3pub(crate) mod errors;
4mod extractors;
5mod generic;
6mod javascript;
7mod manifest;
8pub(crate) mod middleware;
9mod oauth2;
10pub(crate) mod trace;
11mod v1;
12mod v1_domain;
13mod v1_oauth2;
14mod v1_scim;
15mod views;
16
17use self::extractors::ClientConnInfo;
18use self::javascript::*;
19use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
20use crate::config::{AddressSet, Configuration, ServerRole};
21use crate::CoreAction;
22use axum::{
23    body::Body,
24    extract::connect_info::IntoMakeServiceWithConnectInfo,
25    http::{HeaderMap, HeaderValue, Request},
26    middleware::{from_fn, from_fn_with_state},
27    response::Redirect,
28    routing::*,
29    Router,
30};
31use axum_extra::extract::cookie::CookieJar;
32use cidr::IpCidr;
33use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
34use futures::pin_mut;
35use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo};
38use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
39use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
40use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
41use openssl::ssl::{Ssl, SslAcceptor};
42use serde::de::DeserializeOwned;
43use sketching::*;
44use std::fmt::Write;
45use std::io::ErrorKind;
46use std::path::PathBuf;
47use std::pin::Pin;
48use std::sync::Arc;
49use std::{net::SocketAddr, str::FromStr};
50use tokio::{
51    io::{AsyncRead, AsyncWrite},
52    net::{TcpListener, TcpStream},
53    sync::broadcast,
54    sync::mpsc,
55    task,
56};
57use tokio_openssl::SslStream;
58use tower::Service;
59use tower_http::{services::ServeDir, trace::TraceLayer};
60use url::Url;
61use uuid::Uuid;
62
63#[derive(Clone)]
64pub struct ServerState {
65    pub(crate) status_ref: &'static StatusActor,
66    pub(crate) qe_w_ref: &'static QueryServerWriteV1,
67    pub(crate) qe_r_ref: &'static QueryServerReadV1,
68    // Store the token management parts.
69    pub(crate) jws_signer: JwsHs256Signer,
70    pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>,
71    pub(crate) csp_header: HeaderValue,
72    pub(crate) origin: Url,
73    pub(crate) domain: String,
74    // This is set to true by default, and is only false on integration tests.
75    pub(crate) secure_cookies: bool,
76}
77
78impl ServerState {
79    /// Deserialize some input string validating that it was signed by our instance's
80    /// HMAC signer. This is used for short lived server-only sessions and context
81    /// data. This has applications in both accessing cookie content and header content.
82    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
83        match JwsCompact::from_str(input) {
84            Ok(val) => match self.jws_signer.verify(&val) {
85                Ok(val) => val.from_json::<T>().ok(),
86                Err(err) => {
87                    error!(?err, "Failed to deserialise JWT from request");
88                    if matches!(err, JwtError::InvalidSignature) {
89                        // The server has an ephemeral in memory HMAC signer. This is important as
90                        // auth (login) sessions on one node shouldn't validate on another. Sessions
91                        // that are shared beween nodes use the internal ECDSA signer.
92                        //
93                        // But because of this if the server restarts it rolls the key. Additionally
94                        // it can occur if the load balancer isn't sticking sessions to the correct
95                        // node. That can cause this error. So we want to specifically call it out
96                        // to admins so they can investigate that the fault is occurring *outside*
97                        // of kanidm.
98                        warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
99                    }
100                    None
101                }
102            },
103            Err(_) => None,
104        }
105    }
106
107    #[instrument(level = "trace", skip_all)]
108    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
109        // We see if there is a signed header copy first.
110        headers
111            .get(KSESSIONID)
112            .and_then(|hv| {
113                trace!("trying header");
114                // Get the first header value.
115                hv.to_str().ok()
116            })
117            .or_else(|| {
118                trace!("trying cookie");
119                jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
120            })
121            .and_then(|s| {
122                trace!(id_jws = %s);
123                self.deserialise_from_str::<Uuid>(s)
124            })
125    }
126}
127
128pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
129    let mut all_pages: Vec<JavaScriptFile> = Vec::new();
130
131    if !matches!(role, ServerRole::WriteReplicaNoUI) {
132        // let's set up the list of js module hashes
133        let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();
134
135        let filelist = [
136            "external/bootstrap.bundle.min.js",
137            "external/htmx.min.1.9.12.js",
138            "external/confetti.js",
139            "external/base64.js",
140            "modules/cred_update.mjs",
141            "pkhtml.js",
142            "style.js",
143        ];
144
145        for filepath in filelist {
146            match generate_integrity_hash(format!("{}/{}", pkg_path, filepath,)) {
147                Ok(hash) => {
148                    debug!("Integrity hash for {}: {}", filepath, hash);
149                    let js = JavaScriptFile { hash };
150                    all_pages.push(js)
151                }
152                Err(err) => {
153                    admin_error!(
154                        ?err,
155                        "Failed to generate integrity hash for {} - cancelling startup!",
156                        filepath
157                    );
158                    return Err(());
159                }
160            }
161        }
162    }
163    Ok(all_pages)
164}
165
166pub async fn create_https_server(
167    config: Configuration,
168    jws_signer: JwsHs256Signer,
169    status_ref: &'static StatusActor,
170    qe_w_ref: &'static QueryServerWriteV1,
171    qe_r_ref: &'static QueryServerReadV1,
172    server_message_tx: broadcast::Sender<CoreAction>,
173    maybe_tls_acceptor: Option<SslAcceptor>,
174    tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
175) -> Result<task::JoinHandle<()>, ()> {
176    let rx = server_message_tx.subscribe();
177
178    let all_js_files = get_js_files(config.role)?;
179    // set up the CSP headers
180    // script-src 'self'
181    //      'sha384-Zao7ExRXVZOJobzS/uMp0P1jtJz3TTqJU4nYXkdmsjpiVD+/wcwCyX7FGqRIqvIz'
182    //      'sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM';
183
184    let js_directives = all_js_files
185        .into_iter()
186        .map(|f| f.hash)
187        .collect::<Vec<String>>();
188
189    let js_checksums: String = js_directives
190        .iter()
191        .fold(String::new(), |mut output, value| {
192            let _ = write!(output, " 'sha384-{}'", value);
193            output
194        });
195
196    let csp_header = format!(
197        concat!(
198            "default-src 'self'; ",
199            "base-uri 'self' https:; ",
200            "form-action 'self' https:;",
201            "frame-ancestors 'none'; ",
202            "img-src 'self' data:; ",
203            "worker-src 'none'; ",
204            "script-src 'self' 'unsafe-eval'{};",
205        ),
206        js_checksums
207    );
208
209    let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
210        error!(?err, "Unable to generate content security policy");
211    })?;
212
213    let trust_x_forward_for_ips = config
214        .http_client_address_info
215        .trusted_x_forward_for()
216        .map(Arc::new);
217
218    let trusted_proxy_v2_ips = config
219        .http_client_address_info
220        .trusted_proxy_v2()
221        .map(Arc::new);
222
223    let origin = Url::parse(&config.origin)
224        // Should be impossible!
225        .map_err(|err| {
226            error!(?err, "Unable to parse origin URL - refusing to start. You must correct the value for origin. {:?}", config.origin);
227        })?;
228
229    let state = ServerState {
230        status_ref,
231        qe_w_ref,
232        qe_r_ref,
233        jws_signer,
234        trust_x_forward_for_ips,
235        csp_header,
236        origin,
237        domain: config.domain.clone(),
238        secure_cookies: config.integration_test_config.is_none(),
239    };
240
241    let static_routes = match config.role {
242        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
243            Router::new()
244                .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
245                .route("/ui/images/domain", get(v1_domain::image_get))
246                .route("/manifest.webmanifest", get(manifest::manifest)) // skip_route_check
247                // Layers only apply to routes that are *already* added, not the ones
248                // added after.
249                .layer(middleware::compression::new())
250                .layer(from_fn(middleware::caching::cache_me_short))
251                .route("/", get(|| async { Redirect::to("/ui") }))
252                .nest("/ui", views::view_router())
253            // Can't compress on anything that changes
254        }
255        ServerRole::WriteReplicaNoUI => Router::new(),
256    };
257    let app = Router::new()
258        .merge(oauth2::route_setup(state.clone()))
259        .merge(v1_scim::route_setup())
260        .merge(v1::route_setup(state.clone()))
261        .route("/robots.txt", get(generic::robots_txt))
262        .route(
263            views::constants::Urls::WellKnownChangePassword.as_ref(),
264            get(generic::redirect_to_update_credentials),
265        );
266
267    let app = match config.role {
268        ServerRole::WriteReplicaNoUI => app,
269        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
270            let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
271            if !pkg_path.exists() {
272                eprintln!(
273                    "Couldn't find htmx UI package path: ({}), quitting.",
274                    env!("KANIDM_SERVER_UI_PKG_PATH")
275                );
276                std::process::exit(1);
277            }
278            let pkg_router = Router::new()
279                .nest_service("/pkg", ServeDir::new(pkg_path))
280                // TODO: Add in the br precompress
281                .layer(from_fn(middleware::caching::cache_me_short));
282
283            app.merge(pkg_router)
284        }
285    };
286
287    // this sets up the default span which logs the URL etc.
288    let trace_layer = TraceLayer::new_for_http()
289        .make_span_with(trace::DefaultMakeSpanKanidmd::new())
290        // setting these to trace because all they do is print "started processing request", and we are already doing that enough!
291        .on_response(trace::DefaultOnResponseKanidmd::new());
292
293    let app = app
294        .merge(static_routes)
295        .layer(from_fn_with_state(
296            state.clone(),
297            middleware::security_headers::security_headers_layer,
298        ))
299        .layer(from_fn(middleware::version_middleware))
300        .layer(from_fn(
301            middleware::hsts_header::strict_transport_security_layer,
302        ));
303
304    // layer which checks the responses have a content-type of JSON when we're in debug mode
305    #[cfg(any(test, debug_assertions))]
306    let app = app.layer(from_fn(middleware::are_we_json_yet));
307
308    let app = app
309        .route("/status", get(generic::status))
310        // This must be the LAST middleware.
311        // This is because the last middleware here is the first to be entered and the last
312        // to be exited, and this middleware sets up ids' and other bits for for logging
313        // coherence to be maintained.
314        .layer(from_fn(middleware::kopid_middleware))
315        .merge(apidocs::router())
316        // this MUST be the last layer before with_state else the span never starts and everything breaks.
317        .layer(trace_layer)
318        .with_state(state)
319        // the connect_info bit here lets us pick up the remote address of the client
320        .into_make_service_with_connect_info::<ClientConnInfo>();
321
322    let addr = SocketAddr::from_str(&config.address).map_err(|err| {
323        error!(
324            "Failed to parse address ({:?}) from config: {:?}",
325            config.address, err
326        );
327    })?;
328
329    info!("Starting the web server...");
330
331    let listener = match TcpListener::bind(addr).await {
332        Ok(l) => l,
333        Err(err) => {
334            error!(?err, "Failed to bind tcp listener");
335            return Err(());
336        }
337    };
338
339    match maybe_tls_acceptor {
340        Some(tls_acceptor) => Ok(task::spawn(server_tls_loop(
341            tls_acceptor,
342            listener,
343            app,
344            rx,
345            server_message_tx,
346            tls_acceptor_reload_rx,
347            trusted_proxy_v2_ips,
348        ))),
349        None => Ok(task::spawn(server_plaintext_loop(
350            listener,
351            app,
352            rx,
353            trusted_proxy_v2_ips,
354        ))),
355    }
356}
357
358async fn server_tls_loop(
359    mut tls_acceptor: SslAcceptor,
360    listener: TcpListener,
361    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
362    mut rx: broadcast::Receiver<CoreAction>,
363    server_message_tx: broadcast::Sender<CoreAction>,
364    mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
365    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
366) {
367    pin_mut!(listener);
368
369    loop {
370        tokio::select! {
371            Ok(action) = rx.recv() => {
372                match action {
373                    CoreAction::Shutdown => break,
374                }
375            }
376            accept = listener.accept() => {
377                match accept {
378                    Ok((stream, addr)) => {
379                        let tls_acceptor = tls_acceptor.clone();
380                        let app = app.clone();
381                        task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone()));
382                    }
383                    Err(err) => {
384                        error!("Web server exited with {:?}", err);
385                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
386                            error!("Web server failed to send shutdown message! {:?}", err)
387                        };
388                        break;
389                    }
390                }
391            }
392            Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
393                std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
394                info!("Reloaded http tls acceptor");
395            }
396        }
397    }
398
399    info!("Stopped {}", super::TaskName::HttpsServer);
400}
401
402async fn server_plaintext_loop(
403    listener: TcpListener,
404    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
405    mut rx: broadcast::Receiver<CoreAction>,
406    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
407) {
408    pin_mut!(listener);
409
410    loop {
411        tokio::select! {
412            Ok(action) = rx.recv() => {
413                match action {
414                    CoreAction::Shutdown => break,
415                }
416            }
417            accept = listener.accept() => {
418                match accept {
419                    Ok((stream, addr)) => {
420                        let app = app.clone();
421                        task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone()));
422                    }
423                    Err(err) => {
424                        error!("Web server exited with {:?}", err);
425                        break;
426                    }
427                }
428            }
429        }
430    }
431
432    info!("Stopped {}", super::TaskName::HttpsServer);
433}
434
435/// This handles an individual connection.
436pub(crate) async fn handle_conn(
437    stream: TcpStream,
438    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
439    connection_addr: SocketAddr,
440    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
441) -> Result<(), std::io::Error> {
442    let (stream, client_addr) =
443        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
444
445    let client_conn_info = ClientConnInfo {
446        connection_addr,
447        client_addr,
448        client_cert: None,
449    };
450
451    // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
452    // `TokioIo` converts between them.
453    let stream = TokioIo::new(stream);
454
455    process_client_hyper(stream, app, client_conn_info).await
456}
457
458/// This handles an individual connection.
459pub(crate) async fn handle_tls_conn(
460    acceptor: SslAcceptor,
461    stream: TcpStream,
462    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
463    connection_addr: SocketAddr,
464    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
465) -> Result<(), std::io::Error> {
466    let (stream, client_addr) =
467        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
468
469    let ssl = Ssl::new(acceptor.context()).map_err(|e| {
470        error!("Failed to create TLS context: {:?}", e);
471        std::io::Error::from(ErrorKind::ConnectionAborted)
472    })?;
473
474    let mut tls_stream = SslStream::new(ssl, stream).map_err(|err| {
475        error!(?err, "Failed to create TLS stream");
476        std::io::Error::from(ErrorKind::ConnectionAborted)
477    })?;
478
479    match SslStream::accept(Pin::new(&mut tls_stream)).await {
480        Ok(_) => {
481            // Process the client cert (if any)
482            let client_cert = if let Some(peer_cert) = tls_stream.ssl().peer_certificate() {
483                // TODO: This is where we should be checking the CRL!!!
484
485                // Extract the cert from openssl to x509-cert which is a better
486                // parser to handle the various extensions.
487
488                let cert_der = peer_cert.to_der().map_err(|ossl_err| {
489                    error!(?ossl_err, "unable to process x509 certificate as DER");
490                    std::io::Error::from(ErrorKind::ConnectionAborted)
491                })?;
492
493                let certificate = Certificate::from_der(&cert_der).map_err(|ossl_err| {
494                    error!(?ossl_err, "unable to process DER certificate to x509");
495                    std::io::Error::from(ErrorKind::ConnectionAborted)
496                })?;
497
498                let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
499                    error!("subject public key bitstring is not octet aligned");
500                    std::io::Error::from(ErrorKind::ConnectionAborted)
501                })?;
502
503                Some(ClientCertInfo {
504                    public_key_s256,
505                    certificate,
506                })
507            } else {
508                None
509            };
510
511            let client_conn_info = ClientConnInfo {
512                connection_addr,
513                client_addr,
514                client_cert,
515            };
516
517            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
518            // `TokioIo` converts between them.
519            let stream = TokioIo::new(tls_stream);
520
521            process_client_hyper(stream, app, client_conn_info).await
522        }
523        Err(error) => {
524            trace!("Failed to handle connection: {:?}", error);
525            Ok(())
526        }
527    }
528}
529
530async fn process_client_addr(
531    stream: TcpStream,
532    connection_addr: SocketAddr,
533    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
534) -> Result<(TcpStream, SocketAddr), std::io::Error> {
535    let enable_proxy_v2_hdr = trusted_proxy_v2_ips
536        .map(|trusted| {
537            trusted
538                .iter()
539                .any(|ip_cidr| ip_cidr.contains(&connection_addr.ip()))
540        })
541        .unwrap_or_default();
542
543    let (stream, client_addr) = if enable_proxy_v2_hdr {
544        match ProxyHdrV2::parse_from_read(stream).await {
545            Ok((stream, hdr)) => {
546                let remote_socket_addr = match hdr.to_remote_addr() {
547                    RemoteAddress::Local => {
548                        debug!("PROXY protocol liveness check - will not contain client data");
549                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
550                    }
551                    RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
552                    RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
553                    remote_addr => {
554                        error!(?remote_addr, "remote address in proxy header is invalid");
555                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
556                    }
557                };
558
559                (stream, remote_socket_addr)
560            }
561            Err(err) => {
562                error!(?connection_addr, ?err, "Unable to process proxy v2 header");
563                return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
564            }
565        }
566    } else {
567        (stream, connection_addr)
568    };
569
570    Ok((stream, client_addr))
571}
572
573async fn process_client_hyper<T>(
574    stream: TokioIo<T>,
575    mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
576    client_conn_info: ClientConnInfo,
577) -> Result<(), std::io::Error>
578where
579    T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
580{
581    debug!(?client_conn_info);
582
583    let svc = tower::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
584        &mut app,
585        client_conn_info,
586    );
587
588    let svc = svc.await.map_err(|e| {
589        error!("Failed to build HTTP response: {:?}", e);
590        std::io::Error::from(ErrorKind::Other)
591    })?;
592
593    // Hyper also has its own `Service` trait and doesn't use tower. We can use
594    // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
595    // `tower::Service::call`.
596    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
597        // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
598        // tower's `Service` requires `&mut self`.
599        //
600        // We don't need to call `poll_ready` since `Router` is always ready.
601        svc.clone().call(request)
602    });
603
604    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
605        .serve_connection_with_upgrades(stream, hyper_service)
606        .await
607        .map_err(|e| {
608            debug!("Failed to complete connection: {:?}", e);
609            std::io::Error::from(ErrorKind::ConnectionAborted)
610        })
611}