kanidmd_core/https/
mod.rs

1mod apidocs;
2pub(crate) mod cache_buster;
3pub(crate) mod errors;
4mod extractors;
5mod generic;
6mod javascript;
7mod manifest;
8pub(crate) mod middleware;
9mod oauth2;
10pub(crate) mod trace;
11mod v1;
12mod v1_domain;
13mod v1_oauth2;
14mod v1_scim;
15mod views;
16
17use self::extractors::ClientConnInfo;
18use self::javascript::*;
19use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
20use crate::config::{AddressSet, Configuration, ServerRole};
21use crate::CoreAction;
22use axum::{
23    body::Body,
24    extract::connect_info::IntoMakeServiceWithConnectInfo,
25    http::{HeaderMap, HeaderValue, Request},
26    middleware::{from_fn, from_fn_with_state},
27    response::Redirect,
28    routing::*,
29    Router,
30};
31use axum_extra::extract::cookie::CookieJar;
32use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
33use futures::pin_mut;
34use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
35use hashbrown::HashSet;
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo};
38use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
39use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
40use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
41use openssl::ssl::{Ssl, SslAcceptor};
42use serde::de::DeserializeOwned;
43use sketching::*;
44use std::fmt::Write;
45use std::io::ErrorKind;
46use std::net::IpAddr;
47use std::path::PathBuf;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::{net::SocketAddr, str::FromStr};
51use tokio::{
52    io::{AsyncRead, AsyncWrite},
53    net::{TcpListener, TcpStream},
54    sync::broadcast,
55    sync::mpsc,
56    task,
57};
58use tokio_openssl::SslStream;
59use tower::Service;
60use tower_http::{services::ServeDir, trace::TraceLayer};
61use url::Url;
62use uuid::Uuid;
63
64#[derive(Clone)]
65pub struct ServerState {
66    pub(crate) status_ref: &'static StatusActor,
67    pub(crate) qe_w_ref: &'static QueryServerWriteV1,
68    pub(crate) qe_r_ref: &'static QueryServerReadV1,
69    // Store the token management parts.
70    pub(crate) jws_signer: JwsHs256Signer,
71    pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>,
72    pub(crate) csp_header: HeaderValue,
73    pub(crate) origin: Url,
74    pub(crate) domain: String,
75    // This is set to true by default, and is only false on integration tests.
76    pub(crate) secure_cookies: bool,
77}
78
79impl ServerState {
80    /// Deserialize some input string validating that it was signed by our instance's
81    /// HMAC signer. This is used for short lived server-only sessions and context
82    /// data. This has applications in both accessing cookie content and header content.
83    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
84        match JwsCompact::from_str(input) {
85            Ok(val) => match self.jws_signer.verify(&val) {
86                Ok(val) => val.from_json::<T>().ok(),
87                Err(err) => {
88                    error!(?err, "Failed to deserialise JWT from request");
89                    if matches!(err, JwtError::InvalidSignature) {
90                        // The server has an ephemeral in memory HMAC signer. This is important as
91                        // auth (login) sessions on one node shouldn't validate on another. Sessions
92                        // that are shared beween nodes use the internal ECDSA signer.
93                        //
94                        // But because of this if the server restarts it rolls the key. Additionally
95                        // it can occur if the load balancer isn't sticking sessions to the correct
96                        // node. That can cause this error. So we want to specifically call it out
97                        // to admins so they can investigate that the fault is occurring *outside*
98                        // of kanidm.
99                        warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
100                    }
101                    None
102                }
103            },
104            Err(_) => None,
105        }
106    }
107
108    #[instrument(level = "trace", skip_all)]
109    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
110        // We see if there is a signed header copy first.
111        headers
112            .get(KSESSIONID)
113            .and_then(|hv| {
114                trace!("trying header");
115                // Get the first header value.
116                hv.to_str().ok()
117            })
118            .or_else(|| {
119                trace!("trying cookie");
120                jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
121            })
122            .and_then(|s| {
123                trace!(id_jws = %s);
124                self.deserialise_from_str::<Uuid>(s)
125            })
126    }
127}
128
129pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
130    let mut all_pages: Vec<JavaScriptFile> = Vec::new();
131
132    if !matches!(role, ServerRole::WriteReplicaNoUI) {
133        // let's set up the list of js module hashes
134        let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();
135
136        let filelist = [
137            "external/bootstrap.bundle.min.js",
138            "external/htmx.min.1.9.12.js",
139            "external/confetti.js",
140            "external/base64.js",
141            "modules/cred_update.mjs",
142            "pkhtml.js",
143            "style.js",
144        ];
145
146        for filepath in filelist {
147            match generate_integrity_hash(format!("{}/{}", pkg_path, filepath,)) {
148                Ok(hash) => {
149                    debug!("Integrity hash for {}: {}", filepath, hash);
150                    let js = JavaScriptFile { hash };
151                    all_pages.push(js)
152                }
153                Err(err) => {
154                    admin_error!(
155                        ?err,
156                        "Failed to generate integrity hash for {} - cancelling startup!",
157                        filepath
158                    );
159                    return Err(());
160                }
161            }
162        }
163    }
164    Ok(all_pages)
165}
166
167pub async fn create_https_server(
168    config: Configuration,
169    jws_signer: JwsHs256Signer,
170    status_ref: &'static StatusActor,
171    qe_w_ref: &'static QueryServerWriteV1,
172    qe_r_ref: &'static QueryServerReadV1,
173    server_message_tx: broadcast::Sender<CoreAction>,
174    maybe_tls_acceptor: Option<SslAcceptor>,
175    tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
176) -> Result<task::JoinHandle<()>, ()> {
177    let rx = server_message_tx.subscribe();
178
179    let all_js_files = get_js_files(config.role)?;
180    // set up the CSP headers
181    // script-src 'self'
182    //      'sha384-Zao7ExRXVZOJobzS/uMp0P1jtJz3TTqJU4nYXkdmsjpiVD+/wcwCyX7FGqRIqvIz'
183    //      'sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM';
184
185    let js_directives = all_js_files
186        .into_iter()
187        .map(|f| f.hash)
188        .collect::<Vec<String>>();
189
190    let js_checksums: String = js_directives
191        .iter()
192        .fold(String::new(), |mut output, value| {
193            let _ = write!(output, " 'sha384-{}'", value);
194            output
195        });
196
197    let csp_header = format!(
198        concat!(
199            "default-src 'self'; ",
200            "base-uri 'self' https:; ",
201            "form-action 'self' https:;",
202            "frame-ancestors 'none'; ",
203            "img-src 'self' data:; ",
204            "worker-src 'none'; ",
205            "script-src 'self' 'unsafe-eval'{};",
206        ),
207        js_checksums
208    );
209
210    let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
211        error!(?err, "Unable to generate content security policy");
212    })?;
213
214    let trust_x_forward_for_ips = config
215        .http_client_address_info
216        .trusted_x_forward_for()
217        .map(Arc::new);
218
219    let trusted_proxy_v2_ips = config
220        .http_client_address_info
221        .trusted_proxy_v2()
222        .map(Arc::new);
223
224    let origin = Url::parse(&config.origin)
225        // Should be impossible!
226        .map_err(|err| {
227            error!(?err, "Unable to parse origin URL - refusing to start. You must correct the value for origin. {:?}", config.origin);
228        })?;
229
230    let state = ServerState {
231        status_ref,
232        qe_w_ref,
233        qe_r_ref,
234        jws_signer,
235        trust_x_forward_for_ips,
236        csp_header,
237        origin,
238        domain: config.domain.clone(),
239        secure_cookies: config.integration_test_config.is_none(),
240    };
241
242    let static_routes = match config.role {
243        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
244            Router::new()
245                .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
246                .route("/ui/images/domain", get(v1_domain::image_get))
247                .route("/manifest.webmanifest", get(manifest::manifest)) // skip_route_check
248                // Layers only apply to routes that are *already* added, not the ones
249                // added after.
250                .layer(middleware::compression::new())
251                .layer(from_fn(middleware::caching::cache_me_short))
252                .route("/", get(|| async { Redirect::to("/ui") }))
253                .nest("/ui", views::view_router())
254            // Can't compress on anything that changes
255        }
256        ServerRole::WriteReplicaNoUI => Router::new(),
257    };
258    let app = Router::new()
259        .merge(oauth2::route_setup(state.clone()))
260        .merge(v1_scim::route_setup())
261        .merge(v1::route_setup(state.clone()))
262        .route("/robots.txt", get(generic::robots_txt))
263        .route(
264            views::constants::Urls::WellKnownChangePassword.as_ref(),
265            get(generic::redirect_to_update_credentials),
266        );
267
268    let app = match config.role {
269        ServerRole::WriteReplicaNoUI => app,
270        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
271            let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
272            if !pkg_path.exists() {
273                eprintln!(
274                    "Couldn't find htmx UI package path: ({}), quitting.",
275                    env!("KANIDM_SERVER_UI_PKG_PATH")
276                );
277                std::process::exit(1);
278            }
279            let pkg_router = Router::new()
280                .nest_service("/pkg", ServeDir::new(pkg_path))
281                // TODO: Add in the br precompress
282                .layer(from_fn(middleware::caching::cache_me_short));
283
284            app.merge(pkg_router)
285        }
286    };
287
288    // this sets up the default span which logs the URL etc.
289    let trace_layer = TraceLayer::new_for_http()
290        .make_span_with(trace::DefaultMakeSpanKanidmd::new())
291        // setting these to trace because all they do is print "started processing request", and we are already doing that enough!
292        .on_response(trace::DefaultOnResponseKanidmd::new());
293
294    let app = app
295        .merge(static_routes)
296        .layer(from_fn_with_state(
297            state.clone(),
298            middleware::security_headers::security_headers_layer,
299        ))
300        .layer(from_fn(middleware::version_middleware))
301        .layer(from_fn(
302            middleware::hsts_header::strict_transport_security_layer,
303        ));
304
305    // layer which checks the responses have a content-type of JSON when we're in debug mode
306    #[cfg(any(test, debug_assertions))]
307    let app = app.layer(from_fn(middleware::are_we_json_yet));
308
309    let app = app
310        .route("/status", get(generic::status))
311        // This must be the LAST middleware.
312        // This is because the last middleware here is the first to be entered and the last
313        // to be exited, and this middleware sets up ids' and other bits for for logging
314        // coherence to be maintained.
315        .layer(from_fn(middleware::kopid_middleware))
316        .merge(apidocs::router())
317        // this MUST be the last layer before with_state else the span never starts and everything breaks.
318        .layer(trace_layer)
319        .with_state(state)
320        // the connect_info bit here lets us pick up the remote address of the client
321        .into_make_service_with_connect_info::<ClientConnInfo>();
322
323    let addr = SocketAddr::from_str(&config.address).map_err(|err| {
324        error!(
325            "Failed to parse address ({:?}) from config: {:?}",
326            config.address, err
327        );
328    })?;
329
330    info!("Starting the web server...");
331
332    let listener = match TcpListener::bind(addr).await {
333        Ok(l) => l,
334        Err(err) => {
335            error!(?err, "Failed to bind tcp listener");
336            return Err(());
337        }
338    };
339
340    match maybe_tls_acceptor {
341        Some(tls_acceptor) => Ok(task::spawn(server_tls_loop(
342            tls_acceptor,
343            listener,
344            app,
345            rx,
346            server_message_tx,
347            tls_acceptor_reload_rx,
348            trusted_proxy_v2_ips,
349        ))),
350        None => Ok(task::spawn(server_plaintext_loop(
351            listener,
352            app,
353            rx,
354            trusted_proxy_v2_ips,
355        ))),
356    }
357}
358
359async fn server_tls_loop(
360    mut tls_acceptor: SslAcceptor,
361    listener: TcpListener,
362    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
363    mut rx: broadcast::Receiver<CoreAction>,
364    server_message_tx: broadcast::Sender<CoreAction>,
365    mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
366    trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>,
367) {
368    pin_mut!(listener);
369
370    loop {
371        tokio::select! {
372            Ok(action) = rx.recv() => {
373                match action {
374                    CoreAction::Shutdown => break,
375                }
376            }
377            accept = listener.accept() => {
378                match accept {
379                    Ok((stream, addr)) => {
380                        let tls_acceptor = tls_acceptor.clone();
381                        let app = app.clone();
382                        task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone()));
383                    }
384                    Err(err) => {
385                        error!("Web server exited with {:?}", err);
386                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
387                            error!("Web server failed to send shutdown message! {:?}", err)
388                        };
389                        break;
390                    }
391                }
392            }
393            Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
394                std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
395                info!("Reloaded http tls acceptor");
396            }
397        }
398    }
399
400    info!("Stopped {}", super::TaskName::HttpsServer);
401}
402
403async fn server_plaintext_loop(
404    listener: TcpListener,
405    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
406    mut rx: broadcast::Receiver<CoreAction>,
407    trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>,
408) {
409    pin_mut!(listener);
410
411    loop {
412        tokio::select! {
413            Ok(action) = rx.recv() => {
414                match action {
415                    CoreAction::Shutdown => break,
416                }
417            }
418            accept = listener.accept() => {
419                match accept {
420                    Ok((stream, addr)) => {
421                        let app = app.clone();
422                        task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone()));
423                    }
424                    Err(err) => {
425                        error!("Web server exited with {:?}", err);
426                        break;
427                    }
428                }
429            }
430        }
431    }
432
433    info!("Stopped {}", super::TaskName::HttpsServer);
434}
435
436/// This handles an individual connection.
437pub(crate) async fn handle_conn(
438    stream: TcpStream,
439    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
440    connection_addr: SocketAddr,
441    trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>,
442) -> Result<(), std::io::Error> {
443    let (stream, client_addr) =
444        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
445
446    let client_conn_info = ClientConnInfo {
447        connection_addr,
448        client_addr,
449        client_cert: None,
450    };
451
452    // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
453    // `TokioIo` converts between them.
454    let stream = TokioIo::new(stream);
455
456    process_client_hyper(stream, app, client_conn_info).await
457}
458
459/// This handles an individual connection.
460pub(crate) async fn handle_tls_conn(
461    acceptor: SslAcceptor,
462    stream: TcpStream,
463    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
464    connection_addr: SocketAddr,
465    trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>,
466) -> Result<(), std::io::Error> {
467    let (stream, client_addr) =
468        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
469
470    let ssl = Ssl::new(acceptor.context()).map_err(|e| {
471        error!("Failed to create TLS context: {:?}", e);
472        std::io::Error::from(ErrorKind::ConnectionAborted)
473    })?;
474
475    let mut tls_stream = SslStream::new(ssl, stream).map_err(|err| {
476        error!(?err, "Failed to create TLS stream");
477        std::io::Error::from(ErrorKind::ConnectionAborted)
478    })?;
479
480    match SslStream::accept(Pin::new(&mut tls_stream)).await {
481        Ok(_) => {
482            // Process the client cert (if any)
483            let client_cert = if let Some(peer_cert) = tls_stream.ssl().peer_certificate() {
484                // TODO: This is where we should be checking the CRL!!!
485
486                // Extract the cert from openssl to x509-cert which is a better
487                // parser to handle the various extensions.
488
489                let cert_der = peer_cert.to_der().map_err(|ossl_err| {
490                    error!(?ossl_err, "unable to process x509 certificate as DER");
491                    std::io::Error::from(ErrorKind::ConnectionAborted)
492                })?;
493
494                let certificate = Certificate::from_der(&cert_der).map_err(|ossl_err| {
495                    error!(?ossl_err, "unable to process DER certificate to x509");
496                    std::io::Error::from(ErrorKind::ConnectionAborted)
497                })?;
498
499                let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
500                    error!("subject public key bitstring is not octet aligned");
501                    std::io::Error::from(ErrorKind::ConnectionAborted)
502                })?;
503
504                Some(ClientCertInfo {
505                    public_key_s256,
506                    certificate,
507                })
508            } else {
509                None
510            };
511
512            let client_conn_info = ClientConnInfo {
513                connection_addr,
514                client_addr,
515                client_cert,
516            };
517
518            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
519            // `TokioIo` converts between them.
520            let stream = TokioIo::new(tls_stream);
521
522            process_client_hyper(stream, app, client_conn_info).await
523        }
524        Err(error) => {
525            trace!("Failed to handle connection: {:?}", error);
526            Ok(())
527        }
528    }
529}
530
531async fn process_client_addr(
532    stream: TcpStream,
533    connection_addr: SocketAddr,
534    trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>,
535) -> Result<(TcpStream, SocketAddr), std::io::Error> {
536    let enable_proxy_v2_hdr = trusted_proxy_v2_ips
537        .map(|trusted| trusted.contains(&connection_addr.ip()))
538        .unwrap_or_default();
539
540    let (stream, client_addr) = if enable_proxy_v2_hdr {
541        match ProxyHdrV2::parse_from_read(stream).await {
542            Ok((stream, hdr)) => {
543                let remote_socket_addr = match hdr.to_remote_addr() {
544                    RemoteAddress::Local => {
545                        debug!("PROXY protocol liveness check - will not contain client data");
546                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
547                    }
548                    RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
549                    RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
550                    remote_addr => {
551                        error!(?remote_addr, "remote address in proxy header is invalid");
552                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
553                    }
554                };
555
556                (stream, remote_socket_addr)
557            }
558            Err(err) => {
559                error!(?connection_addr, ?err, "Unable to process proxy v2 header");
560                return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
561            }
562        }
563    } else {
564        (stream, connection_addr)
565    };
566
567    Ok((stream, client_addr))
568}
569
570async fn process_client_hyper<T>(
571    stream: TokioIo<T>,
572    mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
573    client_conn_info: ClientConnInfo,
574) -> Result<(), std::io::Error>
575where
576    T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
577{
578    debug!(?client_conn_info);
579
580    let svc = tower::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
581        &mut app,
582        client_conn_info,
583    );
584
585    let svc = svc.await.map_err(|e| {
586        error!("Failed to build HTTP response: {:?}", e);
587        std::io::Error::from(ErrorKind::Other)
588    })?;
589
590    // Hyper also has its own `Service` trait and doesn't use tower. We can use
591    // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
592    // `tower::Service::call`.
593    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
594        // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
595        // tower's `Service` requires `&mut self`.
596        //
597        // We don't need to call `poll_ready` since `Router` is always ready.
598        svc.clone().call(request)
599    });
600
601    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
602        .serve_connection_with_upgrades(stream, hyper_service)
603        .await
604        .map_err(|e| {
605            debug!("Failed to complete connection: {:?}", e);
606            std::io::Error::from(ErrorKind::ConnectionAborted)
607        })
608}