1mod apidocs;
2pub(crate) mod cache_buster;
3pub(crate) mod errors;
4mod extractors;
5mod generic;
6mod javascript;
7mod manifest;
8pub(crate) mod middleware;
9mod oauth2;
10pub(crate) mod trace;
11mod v1;
12mod v1_domain;
13mod v1_oauth2;
14mod v1_scim;
15mod views;
16
17use self::extractors::ClientConnInfo;
18use self::javascript::*;
19use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
20use crate::config::{AddressSet, Configuration, ServerRole};
21use crate::CoreAction;
22use axum::{
23 body::Body,
24 extract::connect_info::IntoMakeServiceWithConnectInfo,
25 http::{HeaderMap, HeaderValue, Request, StatusCode},
26 middleware::{from_fn, from_fn_with_state},
27 response::{IntoResponse, Redirect, Response},
28 routing::*,
29 Router,
30};
31use axum_extra::extract::cookie::CookieJar;
32use cidr::IpCidr;
33use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
34use futures::pin_mut;
35use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
36use hyper::body::Incoming;
37use hyper_util::rt::{TokioExecutor, TokioIo};
38use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
39use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
40use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
41use serde::de::DeserializeOwned;
42use sketching::*;
43use std::fmt::Write;
44use std::io::ErrorKind;
45use std::path::PathBuf;
46use std::sync::Arc;
47use std::{
48 net::{IpAddr, SocketAddr},
49 str::FromStr,
50};
51use tokio::{
52 io::{AsyncRead, AsyncWrite},
53 net::{TcpListener, TcpStream},
54 sync::broadcast,
55 sync::mpsc,
56 task,
57};
58use tokio_rustls::TlsAcceptor;
59use tower::Service;
60use tower_http::{services::ServeDir, trace::TraceLayer};
61use url::Url;
62use uuid::Uuid;
63
64#[derive(Clone)]
65pub struct ServerState {
66 pub(crate) status_ref: &'static StatusActor,
67 pub(crate) qe_w_ref: &'static QueryServerWriteV1,
68 pub(crate) qe_r_ref: &'static QueryServerReadV1,
69 pub(crate) jws_signer: JwsHs256Signer,
71 pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>,
72 pub(crate) csp_header: HeaderValue,
73 pub(crate) csp_header_no_form_action: HeaderValue,
74 pub(crate) origin: Url,
75 pub(crate) domain: String,
76 pub(crate) secure_cookies: bool,
78}
79
80impl ServerState {
81 fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
85 match JwsCompact::from_str(input) {
86 Ok(val) => match self.jws_signer.verify(&val) {
87 Ok(val) => val.from_json::<T>().ok(),
88 Err(err) => {
89 error!(?err, "Failed to deserialise JWT from request");
90 if matches!(err, JwtError::InvalidSignature) {
91 warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
101 }
102 None
103 }
104 },
105 Err(_) => None,
106 }
107 }
108
109 #[instrument(level = "trace", skip_all)]
110 fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
111 headers
113 .get(KSESSIONID)
114 .and_then(|hv| {
115 trace!("trying header");
116 hv.to_str().ok()
118 })
119 .or_else(|| {
120 trace!("trying cookie");
121 jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
122 })
123 .and_then(|s| {
124 trace!(id_jws = %s);
125 self.deserialise_from_str::<Uuid>(s)
126 })
127 }
128}
129
130pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
131 let mut all_pages: Vec<JavaScriptFile> = Vec::new();
132
133 if !matches!(role, ServerRole::WriteReplicaNoUI) {
134 let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();
136
137 let filelist = [
138 "external/bootstrap.bundle.min.js",
139 "external/htmx.min.1.9.12.js",
140 "external/confetti.js",
141 "external/base64.js",
142 "modules/cred_update.mjs",
143 "pkhtml.js",
144 "style.js",
145 ];
146
147 for filepath in filelist {
148 match generate_integrity_hash(format!("{pkg_path}/{filepath}",)) {
149 Ok(hash) => {
150 debug!("Integrity hash for {}: {}", filepath, hash);
151 let js = JavaScriptFile { hash };
152 all_pages.push(js)
153 }
154 Err(err) => {
155 admin_error!(
156 ?err,
157 "Failed to generate integrity hash for {} - cancelling startup!",
158 filepath
159 );
160 return Err(());
161 }
162 }
163 }
164 }
165 Ok(all_pages)
166}
167
168async fn handler_404() -> Response {
169 (StatusCode::NOT_FOUND, "Route not found").into_response()
170}
171
172pub async fn create_https_server(
173 config: Configuration,
174 jws_signer: JwsHs256Signer,
175 status_ref: &'static StatusActor,
176 qe_w_ref: &'static QueryServerWriteV1,
177 qe_r_ref: &'static QueryServerReadV1,
178 server_message_tx: broadcast::Sender<CoreAction>,
179 maybe_tls_acceptor: Option<TlsAcceptor>,
180 tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
181) -> Result<task::JoinHandle<()>, ()> {
182 let rx = server_message_tx.subscribe();
183
184 let all_js_files = get_js_files(config.role)?;
185 let js_directives = all_js_files
191 .into_iter()
192 .map(|f| f.hash)
193 .collect::<Vec<String>>();
194
195 let js_checksums: String = js_directives
196 .iter()
197 .fold(String::new(), |mut output, value| {
198 let _ = write!(output, " 'sha384-{value}'");
199 output
200 });
201
202 let csp_header = format!(
203 concat!(
204 "default-src 'self'; ",
205 "base-uri 'self' https:; ",
206 "form-action 'self'; ",
207 "frame-ancestors 'none'; ",
208 "img-src 'self' data:; ",
209 "worker-src 'none'; ",
210 "script-src 'self' 'unsafe-eval'{};",
211 ),
212 js_checksums
213 );
214
215 let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
216 error!(?err, "Unable to generate content security policy");
217 })?;
218
219 let csp_header_no_form_action = format!(
227 concat!(
228 "default-src 'self'; ",
229 "base-uri 'self' https:; ",
230 "frame-ancestors 'none'; ",
231 "img-src 'self' data:; ",
232 "worker-src 'none'; ",
233 "script-src 'self' 'unsafe-eval'{};",
234 ),
235 js_checksums
236 );
237
238 let csp_header_no_form_action =
239 HeaderValue::from_str(&csp_header_no_form_action).map_err(|err| {
240 error!(
241 ?err,
242 "Unable to generate content security policy with no form action"
243 );
244 })?;
245
246 let trust_x_forward_for_ips = config
247 .http_client_address_info
248 .trusted_x_forward_for()
249 .map(Arc::new);
250
251 let trusted_proxy_v2_ips = config
252 .http_client_address_info
253 .trusted_proxy_v2()
254 .map(Arc::new);
255
256 let state = ServerState {
257 status_ref,
258 qe_w_ref,
259 qe_r_ref,
260 jws_signer,
261 trust_x_forward_for_ips,
262 csp_header,
263 csp_header_no_form_action,
264 origin: config.origin,
265 domain: config.domain.clone(),
266 secure_cookies: config.integration_test_config.is_none(),
267 };
268
269 let static_routes = match config.role {
270 ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
271 Router::new()
272 .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
273 .route("/ui/images/domain", get(v1_domain::image_get))
274 .route("/manifest.webmanifest", get(manifest::manifest)) .layer(middleware::compression::new())
278 .layer(from_fn(middleware::caching::cache_me_short))
279 .route("/", get(|| async { Redirect::to("/ui") }))
280 .nest("/ui", views::view_router(state.clone()))
281 }
283 ServerRole::WriteReplicaNoUI => Router::new(),
284 };
285 let app = Router::new()
286 .merge(oauth2::route_setup(state.clone()))
287 .merge(v1_scim::route_setup())
288 .merge(v1::route_setup(state.clone()))
289 .route("/robots.txt", get(generic::robots_txt))
290 .route(
291 views::constants::Urls::WellKnownChangePassword.as_ref(),
292 get(generic::redirect_to_update_credentials),
293 );
294
295 let app = match config.role {
296 ServerRole::WriteReplicaNoUI => app,
297 ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
298 let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
299 if !pkg_path.exists() {
300 eprintln!(
301 "Couldn't find htmx UI package path: ({}), quitting.",
302 env!("KANIDM_SERVER_UI_PKG_PATH")
303 );
304 std::process::exit(1);
305 }
306 let pkg_router = Router::new()
307 .nest_service("/pkg", ServeDir::new(pkg_path))
308 .layer(from_fn(middleware::caching::cache_me_short));
310
311 app.merge(pkg_router)
312 }
313 };
314
315 let trace_layer = TraceLayer::new_for_http()
317 .make_span_with(trace::DefaultMakeSpanKanidmd::new())
318 .on_response(trace::DefaultOnResponseKanidmd::new());
320
321 let app = app
322 .merge(static_routes)
323 .layer(from_fn_with_state(
324 state.clone(),
325 middleware::security_headers::security_headers_layer,
326 ))
327 .layer(from_fn(middleware::version_middleware))
328 .layer(from_fn(
329 middleware::hsts_header::strict_transport_security_layer,
330 ));
331
332 #[cfg(any(test, debug_assertions))]
334 let app = app.layer(from_fn(middleware::are_we_json_yet));
335
336 let app = app
337 .route("/status", get(generic::status))
338 .fallback(handler_404)
340 .layer(from_fn_with_state(
345 state.clone(),
346 middleware::ip_address_middleware,
347 ))
348 .layer(from_fn(middleware::kopid_middleware))
349 .merge(apidocs::router())
350 .layer(trace_layer)
352 .with_state(state)
353 .into_make_service_with_connect_info::<ClientConnInfo>();
355
356 let addr = SocketAddr::from_str(&config.address).map_err(|err| {
357 error!(
358 "Failed to parse address ({:?}) from config: {:?}",
359 config.address, err
360 );
361 })?;
362
363 info!("Starting the web server...");
364
365 let listener = match TcpListener::bind(addr).await {
366 Ok(l) => l,
367 Err(err) => {
368 error!(?err, "Failed to bind tcp listener");
369 return Err(());
370 }
371 };
372
373 match maybe_tls_acceptor {
374 Some(tls_acceptor) => Ok(task::spawn(server_tls_loop(
375 tls_acceptor,
376 listener,
377 app,
378 rx,
379 server_message_tx,
380 tls_acceptor_reload_rx,
381 trusted_proxy_v2_ips,
382 ))),
383 None => Ok(task::spawn(server_plaintext_loop(
384 listener,
385 app,
386 rx,
387 trusted_proxy_v2_ips,
388 ))),
389 }
390}
391
392async fn server_tls_loop(
393 mut tls_acceptor: TlsAcceptor,
394 listener: TcpListener,
395 app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
396 mut rx: broadcast::Receiver<CoreAction>,
397 server_message_tx: broadcast::Sender<CoreAction>,
398 mut tls_acceptor_reload_rx: mpsc::Receiver<TlsAcceptor>,
399 trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
400) {
401 pin_mut!(listener);
402
403 loop {
404 tokio::select! {
405 Ok(action) = rx.recv() => {
406 match action {
407 CoreAction::Shutdown => break,
408 }
409 }
410 accept = listener.accept() => {
411 match accept {
412 Ok((stream, addr)) => {
413 let tls_acceptor = tls_acceptor.clone();
414 let app = app.clone();
415 task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone()));
416 }
417 Err(err) => {
418 error!("Web server exited with {:?}", err);
419 if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
420 error!("Web server failed to send shutdown message! {:?}", err)
421 };
422 break;
423 }
424 }
425 }
426 Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
427 std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
428 info!("Reloaded http tls acceptor");
429 }
430 }
431 }
432
433 info!("Stopped {}", super::TaskName::HttpsServer);
434}
435
436async fn server_plaintext_loop(
437 listener: TcpListener,
438 app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
439 mut rx: broadcast::Receiver<CoreAction>,
440 trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
441) {
442 pin_mut!(listener);
443
444 loop {
445 tokio::select! {
446 Ok(action) = rx.recv() => {
447 match action {
448 CoreAction::Shutdown => break,
449 }
450 }
451 accept = listener.accept() => {
452 match accept {
453 Ok((stream, addr)) => {
454 let app = app.clone();
455 task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone()));
456 }
457 Err(err) => {
458 error!("Web server exited with {:?}", err);
459 break;
460 }
461 }
462 }
463 }
464 }
465
466 info!("Stopped {}", super::TaskName::HttpsServer);
467}
468
469pub(crate) async fn handle_conn(
471 stream: TcpStream,
472 app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
473 connection_addr: SocketAddr,
474 trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
475) -> Result<(), std::io::Error> {
476 let (stream, client_ip_addr) =
477 process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
478
479 let client_conn_info = ClientConnInfo {
480 connection_addr,
481 client_ip_addr,
482 client_cert: None,
483 };
484
485 let stream = TokioIo::new(stream);
488
489 process_client_hyper(stream, app, client_conn_info).await
490}
491
492pub(crate) async fn handle_tls_conn(
494 acceptor: TlsAcceptor,
495 stream: TcpStream,
496 app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
497 connection_addr: SocketAddr,
498 trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
499) -> Result<(), std::io::Error> {
500 let (stream, client_ip_addr) =
501 process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?;
502
503 let tls_stream = acceptor.accept(stream).await.map_err(|err| {
504 error!(?err, "Failed to create TLS stream");
505 std::io::Error::from(ErrorKind::ConnectionAborted)
506 })?;
507
508 let maybe_peer_cert = tls_stream
509 .get_ref()
510 .1
511 .peer_certificates()
512 .and_then(|peer_certs| peer_certs.first());
514
515 let client_cert = if let Some(peer_cert) = maybe_peer_cert {
517 let certificate = Certificate::from_der(peer_cert).map_err(|ossl_err| {
523 error!(?ossl_err, "unable to process DER certificate to x509");
524 std::io::Error::from(ErrorKind::ConnectionAborted)
525 })?;
526
527 let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
528 error!("subject public key bitstring is not octet aligned");
529 std::io::Error::from(ErrorKind::ConnectionAborted)
530 })?;
531
532 Some(ClientCertInfo {
533 public_key_s256,
534 certificate,
535 })
536 } else {
537 None
538 };
539
540 let client_conn_info = ClientConnInfo {
541 connection_addr,
542 client_ip_addr,
543 client_cert,
544 };
545
546 let stream = TokioIo::new(tls_stream);
549
550 process_client_hyper(stream, app, client_conn_info).await
551}
552
553async fn process_client_addr(
554 stream: TcpStream,
555 connection_addr: SocketAddr,
556 trusted_proxy_v2_ips: Option<Arc<Vec<IpCidr>>>,
557) -> Result<(TcpStream, IpAddr), std::io::Error> {
558 let enable_proxy_v2_hdr = trusted_proxy_v2_ips
559 .map(|trusted| {
560 trusted
561 .iter()
562 .any(|ip_cidr| ip_cidr.contains(&connection_addr.ip().to_canonical()))
563 })
564 .unwrap_or_default();
565
566 let (stream, client_addr) = if enable_proxy_v2_hdr {
567 match ProxyHdrV2::parse_from_read(stream).await {
568 Ok((stream, hdr)) => {
569 let remote_socket_addr = match hdr.to_remote_addr() {
570 RemoteAddress::Local => {
571 debug!("PROXY protocol liveness check - will not contain client data");
572 connection_addr
574 }
575 RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
576 RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
577 remote_addr => {
578 error!(?remote_addr, "remote address in proxy header is invalid");
579 return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
580 }
581 };
582
583 (stream, remote_socket_addr)
584 }
585 Err(err) => {
586 error!(?connection_addr, ?err, "Unable to process proxy v2 header");
587 return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
588 }
589 }
590 } else {
591 (stream, connection_addr)
592 };
593
594 Ok((stream, client_addr.ip()))
595}
596
597async fn process_client_hyper<T>(
598 stream: TokioIo<T>,
599 mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
600 client_conn_info: ClientConnInfo,
601) -> Result<(), std::io::Error>
602where
603 T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static,
604{
605 debug!(?client_conn_info);
606
607 let svc = tower::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
608 &mut app,
609 client_conn_info,
610 );
611
612 let svc = svc.await.map_err(|e| {
613 error!("Failed to build HTTP response: {:?}", e);
614 std::io::Error::from(ErrorKind::Other)
615 })?;
616
617 let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
621 svc.clone().call(request)
626 });
627
628 hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
629 .serve_connection_with_upgrades(stream, hyper_service)
630 .await
631 .map_err(|e| {
632 debug!("Failed to complete connection: {:?}", e);
633 std::io::Error::from(ErrorKind::ConnectionAborted)
634 })
635}