1mod apidocs;
2pub(crate) mod cache_buster;
3pub(crate) mod errors;
4mod extractors;
5mod generic;
6mod javascript;
7mod manifest;
8pub(crate) mod middleware;
9mod oauth2;
10pub(crate) mod trace;
11mod v1;
12mod v1_domain;
13mod v1_oauth2;
14mod v1_scim;
15mod views;
16
17use self::extractors::ClientConnInfo;
18use self::javascript::*;
19use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
20use crate::config::{AddressSet, Configuration, ServerRole};
21use crate::CoreAction;
22use axum::{
23    body::Body,
24    extract::connect_info::IntoMakeServiceWithConnectInfo,
25    http::{HeaderMap, HeaderValue, Request, StatusCode},
26    middleware::{from_fn, from_fn_with_state},
27    response::{IntoResponse, Redirect, Response},
28    routing::*,
29    Router,
30};
31use axum_extra::extract::cookie::CookieJar;
32use cidr::IpCidr;
33use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
34use futures::pin_mut;
35use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo};
38use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
39use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
40use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
41use serde::de::DeserializeOwned;
42use sketching::*;
43use std::fmt::Write;
44use std::io::ErrorKind;
45use std::path::PathBuf;
46use std::sync::Arc;
47use std::{
48    net::{IpAddr, SocketAddr},
49    str::FromStr,
50};
51use tokio::{
52    io::{AsyncRead, AsyncWrite},
53    net::{TcpListener, TcpStream},
54    sync::broadcast,
55    sync::mpsc,
56    task,
57};
58use tokio_rustls::TlsAcceptor;
59use tower::Service;
60use tower_http::{services::ServeDir, trace::TraceLayer};
61use url::Url;
62use uuid::Uuid;
63
/// State shared with the HTTP handlers and middleware.
///
/// A single value is built in `create_https_server` and cloned into the routers and
/// stateful middleware layers that need it.
#[derive(Clone)]
pub struct ServerState {
    /// Reference to the status actor.
    pub(crate) status_ref: &'static StatusActor,
    /// Reference to the write side of the query server.
    pub(crate) qe_w_ref: &'static QueryServerWriteV1,
    /// Reference to the read side of the query server.
    pub(crate) qe_r_ref: &'static QueryServerReadV1,
    /// HS256 signer used to verify the signed session identifiers sent by clients.
    pub(crate) jws_signer: JwsHs256Signer,
    /// Address set taken from the configuration's `trusted_x_forward_for()`; `None` when
    /// the configuration supplies none.
    pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>,
    /// Pre-rendered Content-Security-Policy header value.
    pub(crate) csp_header: HeaderValue,
    /// The configured origin URL of this server.
    pub(crate) origin: Url,
    /// The configured domain name of this server.
    pub(crate) domain: String,
    /// Set to `true` unless an integration test configuration is present.
    pub(crate) secure_cookies: bool,
}
78
79impl ServerState {
80    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
84        match JwsCompact::from_str(input) {
85            Ok(val) => match self.jws_signer.verify(&val) {
86                Ok(val) => val.from_json::<T>().ok(),
87                Err(err) => {
88                    error!(?err, "Failed to deserialise JWT from request");
89                    if matches!(err, JwtError::InvalidSignature) {
90                        warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
100                    }
101                    None
102                }
103            },
104            Err(_) => None,
105        }
106    }
107
108    #[instrument(level = "trace", skip_all)]
109    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
110        headers
112            .get(KSESSIONID)
113            .and_then(|hv| {
114                trace!("trying header");
115                hv.to_str().ok()
117            })
118            .or_else(|| {
119                trace!("trying cookie");
120                jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
121            })
122            .and_then(|s| {
123                trace!(id_jws = %s);
124                self.deserialise_from_str::<Uuid>(s)
125            })
126    }
127}
128
129pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
130    let mut all_pages: Vec<JavaScriptFile> = Vec::new();
131
132    if !matches!(role, ServerRole::WriteReplicaNoUI) {
133        let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();
135
136        let filelist = [
137            "external/bootstrap.bundle.min.js",
138            "external/htmx.min.1.9.12.js",
139            "external/confetti.js",
140            "external/base64.js",
141            "modules/cred_update.mjs",
142            "pkhtml.js",
143            "style.js",
144        ];
145
146        for filepath in filelist {
147            match generate_integrity_hash(format!("{pkg_path}/{filepath}",)) {
148                Ok(hash) => {
149                    debug!("Integrity hash for {}: {}", filepath, hash);
150                    let js = JavaScriptFile { hash };
151                    all_pages.push(js)
152                }
153                Err(err) => {
154                    admin_error!(
155                        ?err,
156                        "Failed to generate integrity hash for {} - cancelling startup!",
157                        filepath
158                    );
159                    return Err(());
160                }
161            }
162        }
163    }
164    Ok(all_pages)
165}
166
167async fn handler_404() -> Response {
168    (StatusCode::NOT_FOUND, "Route not found").into_response()
169}
170
171pub async fn create_https_server(
172    config: Configuration,
173    jws_signer: JwsHs256Signer,
174    status_ref: &'static StatusActor,
175    qe_w_ref: &'static QueryServerWriteV1,
176    qe_r_ref: &'static QueryServerReadV1,
177    server_message_tx: broadcast::Sender<CoreAction>,
178    maybe_tls_acceptor: Option<TlsAcceptor>,
179    tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
180) -> Result<task::JoinHandle<()>, ()> {
181    let rx = server_message_tx.subscribe();
182
183    let all_js_files = get_js_files(config.role)?;
184    let js_directives = all_js_files
190        .into_iter()
191        .map(|f| f.hash)
192        .collect::<Vec<String>>();
193
194    let js_checksums: String = js_directives
195        .iter()
196        .fold(String::new(), |mut output, value| {
197            let _ = write!(output, " 'sha384-{value}'");
198            output
199        });
200
201    let csp_header = format!(
202        concat!(
203            "default-src 'self'; ",
204            "base-uri 'self' https:; ",
205            "form-action 'self' https:;",
206            "frame-ancestors 'none'; ",
207            "img-src 'self' data:; ",
208            "worker-src 'none'; ",
209            "script-src 'self' 'unsafe-eval'{};",
210        ),
211        js_checksums
212    );
213
214    let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
215        error!(?err, "Unable to generate content security policy");
216    })?;
217
218    let trust_x_forward_for_ips = config
219        .http_client_address_info
220        .trusted_x_forward_for()
221        .map(Arc::new);
222
223    let trusted_proxy_v2_ips = config
224        .http_client_address_info
225        .trusted_proxy_v2()
226        .map(Arc::new);
227
228    let state = ServerState {
229        status_ref,
230        qe_w_ref,
231        qe_r_ref,
232        jws_signer,
233        trust_x_forward_for_ips,
234        csp_header,
235        origin: config.origin,
236        domain: config.domain.clone(),
237        secure_cookies: config.integration_test_config.is_none(),
238    };
239
240    let static_routes = match config.role {
241        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
242            Router::new()
243                .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
244                .route("/ui/images/domain", get(v1_domain::image_get))
245                .route("/manifest.webmanifest", get(manifest::manifest)) .layer(middleware::compression::new())
249                .layer(from_fn(middleware::caching::cache_me_short))
250                .route("/", get(|| async { Redirect::to("/ui") }))
251                .nest("/ui", views::view_router())
252            }
254        ServerRole::WriteReplicaNoUI => Router::new(),
255    };
256    let app = Router::new()
257        .merge(oauth2::route_setup(state.clone()))
258        .merge(v1_scim::route_setup())
259        .merge(v1::route_setup(state.clone()))
260        .route("/robots.txt", get(generic::robots_txt))
261        .route(
262            views::constants::Urls::WellKnownChangePassword.as_ref(),
263            get(generic::redirect_to_update_credentials),
264        );
265
266    let app = match config.role {
267        ServerRole::WriteReplicaNoUI => app,
268        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
269            let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
270            if !pkg_path.exists() {
271                eprintln!(
272                    "Couldn't find htmx UI package path: ({}), quitting.",
273                    env!("KANIDM_SERVER_UI_PKG_PATH")
274                );
275                std::process::exit(1);
276            }
277            let pkg_router = Router::new()
278                .nest_service("/pkg", ServeDir::new(pkg_path))
279                .layer(from_fn(middleware::caching::cache_me_short));
281
282            app.merge(pkg_router)
283        }
284    };
285
286    let trace_layer = TraceLayer::new_for_http()
288        .make_span_with(trace::DefaultMakeSpanKanidmd::new())
289        .on_response(trace::DefaultOnResponseKanidmd::new());
291
292    let app = app
293        .merge(static_routes)
294        .layer(from_fn_with_state(
295            state.clone(),
296            middleware::security_headers::security_headers_layer,
297        ))
298        .layer(from_fn(middleware::version_middleware))
299        .layer(from_fn(
300            middleware::hsts_header::strict_transport_security_layer,
301        ));
302
303    #[cfg(any(test, debug_assertions))]
305    let app = app.layer(from_fn(middleware::are_we_json_yet));
306
307    let app = app
308        .route("/status", get(generic::status))
309        .fallback(handler_404)
311        .layer(from_fn_with_state(
316            state.clone(),
317            middleware::ip_address_middleware,
318        ))
319        .layer(from_fn(middleware::kopid_middleware))
320        .merge(apidocs::router())
321        .layer(trace_layer)
323        .with_state(state)
324        .into_make_service_with_connect_info::<ClientConnInfo>();
326
327    let addr = SocketAddr::from_str(&config.address).map_err(|err| {
328        error!(
329            "Failed to parse address ({:?}) from config: {:?}",
330            config.address, err
331        );
332    })?;
333
334    info!("Starting the web server...");
335
336    let listener = match TcpListener::bind(addr).await {
337        Ok(l) => l,
338        Err(err) => {
339            error!(?err, "Failed to bind tcp listener");
340            return Err(());
341        }
342    };
343
344    match maybe_tls_acceptor {
345        Some(tls_acceptor) => Ok(task::spawn(server_tls_loop(
346            tls_acceptor,
347            listener,
348            app,
349            rx,
350            server_message_tx,
351            tls_acceptor_reload_rx,
352            trusted_proxy_v2_ips,
353        ))),
354        None => Ok(task::spawn(server_plaintext_loop(
355            listener,
356            app,
357            rx,
358            trusted_proxy_v2_ips,
359        ))),
360    }
361}
362
363async fn server_tls_loop(
364    mut tls_acceptor: TlsAcceptor,
365    listener: TcpListener,
366    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
367    mut rx: broadcast::Receiver<CoreAction>,
368    server_message_tx: broadcast::Sender<CoreAction>,
369    mut tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
370    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
371) {
372    pin_mut!(listener);
373
374    loop {
375        tokio::select! {
376            Ok(action) = rx.recv() => {
377                match action {
378                    CoreAction::Shutdown => break,
379                }
380            }
381            accept = listener.accept() => {
382                match accept {
383                    Ok((stream, addr)) => {
384                        let tls_acceptor = tls_acceptor.clone();
385                        let app = app.clone();
386                        task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone()));
387                    }
388                    Err(err) => {
389                        error!("Web server exited with {:?}", err);
390                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
391                            error!("Web server failed to send shutdown message! {:?}", err)
392                        };
393                        break;
394                    }
395                }
396            }
397            Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
398                std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
399                info!("Reloaded http tls acceptor");
400            }
401        }
402    }
403
404    info!("Stopped {}", super::TaskName::HttpsServer);
405}
406
407async fn server_plaintext_loop(
408    listener: TcpListener,
409    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
410    mut rx: broadcast::Receiver<CoreAction>,
411    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
412) {
413    pin_mut!(listener);
414
415    loop {
416        tokio::select! {
417            Ok(action) = rx.recv() => {
418                match action {
419                    CoreAction::Shutdown => break,
420                }
421            }
422            accept = listener.accept() => {
423                match accept {
424                    Ok((stream, addr)) => {
425                        let app = app.clone();
426                        task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone()));
427                    }
428                    Err(err) => {
429                        error!("Web server exited with {:?}", err);
430                        break;
431                    }
432                }
433            }
434        }
435    }
436
437    info!("Stopped {}", super::TaskName::HttpsServer);
438}
439
440pub(crate) async fn handle_conn(
442    stream: TcpStream,
443    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
444    connection_addr: SocketAddr,
445    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
446) -> Result<(), std::io::Error> {
447    let (stream, client_ip_addr) =
448        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
449
450    let client_conn_info = ClientConnInfo {
451        connection_addr,
452        client_ip_addr,
453        client_cert: None,
454    };
455
456    let stream = TokioIo::new(stream);
459
460    process_client_hyper(stream, app, client_conn_info).await
461}
462
463pub(crate) async fn handle_tls_conn(
465    acceptor: TlsAcceptor,
466    stream: TcpStream,
467    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
468    connection_addr: SocketAddr,
469    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
470) -> Result<(), std::io::Error> {
471    let (stream, client_ip_addr) =
472        process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
473
474    let tls_stream = acceptor.accept(stream).await.map_err(|err| {
475        error!(?err, "Failed to create TLS stream");
476        std::io::Error::from(ErrorKind::ConnectionAborted)
477    })?;
478
479    let maybe_peer_cert = tls_stream
480        .get_ref()
481        .1
482        .peer_certificates()
483        .and_then(|peer_certs| peer_certs.first());
485
486    let client_cert = if let Some(peer_cert) = maybe_peer_cert {
488        let certificate = Certificate::from_der(peer_cert).map_err(|ossl_err| {
494            error!(?ossl_err, "unable to process DER certificate to x509");
495            std::io::Error::from(ErrorKind::ConnectionAborted)
496        })?;
497
498        let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
499            error!("subject public key bitstring is not octet aligned");
500            std::io::Error::from(ErrorKind::ConnectionAborted)
501        })?;
502
503        Some(ClientCertInfo {
504            public_key_s256,
505            certificate,
506        })
507    } else {
508        None
509    };
510
511    let client_conn_info = ClientConnInfo {
512        connection_addr,
513        client_ip_addr,
514        client_cert,
515    };
516
517    let stream = TokioIo::new(tls_stream);
520
521    process_client_hyper(stream, app, client_conn_info).await
522}
523
524async fn process_client_addr(
525    stream: TcpStream,
526    connection_addr: SocketAddr,
527    trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
528) -> Result<(TcpStream, IpAddr), std::io::Error> {
529    let enable_proxy_v2_hdr = trusted_proxy_v2_ips
530        .map(|trusted| {
531            trusted
532                .iter()
533                .any(|ip_cidr| ip_cidr.contains(&connection_addr.ip()))
534        })
535        .unwrap_or_default();
536
537    let (stream, client_addr) = if enable_proxy_v2_hdr {
538        match ProxyHdrV2::parse_from_read(stream).await {
539            Ok((stream, hdr)) => {
540                let remote_socket_addr = match hdr.to_remote_addr() {
541                    RemoteAddress::Local => {
542                        debug!("PROXY protocol liveness check - will not contain client data");
543                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
544                    }
545                    RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
546                    RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
547                    remote_addr => {
548                        error!(?remote_addr, "remote address in proxy header is invalid");
549                        return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
550                    }
551                };
552
553                (stream, remote_socket_addr)
554            }
555            Err(err) => {
556                error!(?connection_addr, ?err, "Unable to process proxy v2 header");
557                return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
558            }
559        }
560    } else {
561        (stream, connection_addr)
562    };
563
564    Ok((stream, client_addr.ip()))
565}
566
567async fn process_client_hyper<T>(
568    stream: TokioIo<T>,
569    mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
570    client_conn_info: ClientConnInfo,
571) -> Result<(), std::io::Error>
572where
573    T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
574{
575    debug!(?client_conn_info);
576
577    let svc = tower::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
578        &mut app,
579        client_conn_info,
580    );
581
582    let svc = svc.await.map_err(|e| {
583        error!("Failed to build HTTP response: {:?}", e);
584        std::io::Error::from(ErrorKind::Other)
585    })?;
586
587    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
591        svc.clone().call(request)
596    });
597
598    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
599        .serve_connection_with_upgrades(stream, hyper_service)
600        .await
601        .map_err(|e| {
602            debug!("Failed to complete connection: {:?}", e);
603            std::io::Error::from(ErrorKind::ConnectionAborted)
604        })
605}