mod apidocs;
pub(crate) mod cache_buster;
pub(crate) mod errors;
mod extractors;
mod generic;
mod javascript;
mod manifest;
pub(crate) mod middleware;
mod oauth2;
pub(crate) mod trace;
mod v1;
mod v1_domain;
mod v1_oauth2;
mod v1_scim;
mod views;
use self::extractors::ClientConnInfo;
use self::javascript::*;
use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
use crate::config::{Configuration, ServerRole};
use crate::CoreAction;
use axum::{
body::Body,
extract::connect_info::IntoMakeServiceWithConnectInfo,
http::{HeaderMap, HeaderValue, Request},
middleware::{from_fn, from_fn_with_state},
response::Redirect,
routing::*,
Router,
};
use axum_extra::extract::cookie::CookieJar;
use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
use futures::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
use openssl::ssl::{Ssl, SslAcceptor};
use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
use serde::de::DeserializeOwned;
use sketching::*;
use std::fmt::Write;
use tokio::{
net::{TcpListener, TcpStream},
sync::broadcast,
sync::mpsc,
task,
};
use tokio_openssl::SslStream;
use tower::Service;
use tower_http::{services::ServeDir, trace::TraceLayer};
use url::Url;
use uuid::Uuid;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::pin::Pin;
use std::{net::SocketAddr, str::FromStr};
/// State shared with every axum handler and middleware of the web server.
///
/// It derives `Clone` because axum requires router state to be cloneable. The
/// actor fields are `'static` references, so copying them is cheap.
#[derive(Clone)]
pub struct ServerState {
    // Backend actors.
    /// Handle to the server status actor.
    pub(crate) status_ref: &'static StatusActor,
    /// Read-side query server interface.
    pub(crate) qe_r_ref: &'static QueryServerReadV1,
    /// Write-side query server interface.
    pub(crate) qe_w_ref: &'static QueryServerWriteV1,

    // Signing.
    /// HS256 signer; `deserialise_from_str` uses it to verify compact JWS input.
    pub(crate) jws_signer: JwsHs256Signer,

    // HTTP / deployment settings.
    /// Copied from the configuration. Presumably decides whether a proxy's
    /// `X-Forwarded-For` header is believed; confirm at the use sites.
    pub(crate) trust_x_forward_for: bool,
    /// Pre-rendered content security policy header value.
    pub(crate) csp_header: HeaderValue,
    /// Parsed public origin URL of this server.
    pub(crate) origin: Url,
    /// Configured domain name.
    pub(crate) domain: String,
    /// `true` unless an integration-test configuration is active.
    pub(crate) secure_cookies: bool,
}
impl ServerState {
    /// Verify a compact JWS string with this server's signer and decode its payload as `T`.
    ///
    /// Returns `None` in three cases:
    /// - the input is not a compact JWS;
    /// - signature verification fails (this is logged);
    /// - the payload does not deserialise into `T`.
    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
        // Syntactically invalid input is rejected without logging.
        let compact = JwsCompact::from_str(input).ok()?;

        let verified = match self.jws_signer.verify(&compact) {
            Ok(verified) => verified,
            Err(err) => {
                error!(?err, "Failed to deserialise JWT from request");
                if let JwtError::InvalidSignature = err {
                    warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
                }
                return None;
            }
        };

        verified.from_json::<T>().ok()
    }

    /// Extract the current authentication session id from the request.
    ///
    /// The `KSESSIONID` header is preferred. The `COOKIE_AUTH_SESSION_ID`
    /// cookie is consulted when the header is absent or not valid visible
    /// ASCII. The chosen value must be a JWS signed by this server wrapping
    /// a `Uuid`.
    #[instrument(level = "trace", skip_all)]
    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
        let from_header = headers.get(KSESSIONID).and_then(|value| {
            trace!("trying header");
            value.to_str().ok()
        });

        let raw_session = match from_header {
            Some(token) => Some(token),
            None => {
                trace!("trying cookie");
                jar.get(COOKIE_AUTH_SESSION_ID).map(|cookie| cookie.value())
            }
        };

        let token = raw_session?;
        trace!(id_jws = %token);
        self.deserialise_from_str::<Uuid>(token)
    }
}
/// Compute integrity hashes for the JavaScript files bundled with the web UI.
///
/// Roles without a UI (`ServerRole::WriteReplicaNoUI`) get an empty list. For
/// every other role, each listed file under the compile-time
/// `KANIDM_HTMX_UI_PKG_PATH` directory is hashed in order.
///
/// # Errors
///
/// The first file whose hash cannot be generated is logged, and `Err(())` is
/// returned without hashing the remaining files.
pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
    if matches!(role, ServerRole::WriteReplicaNoUI) {
        return Ok(Vec::new());
    }

    let pkg_path = env!("KANIDM_HTMX_UI_PKG_PATH");
    let bundled_scripts = [
        "external/bootstrap.bundle.min.js",
        "external/htmx.min.1.9.12.js",
        "external/confetti.js",
        "external/base64.js",
        "modules/cred_update.mjs",
        "pkhtml.js",
    ];

    // Collecting into a Result stops at the first failure.
    bundled_scripts
        .into_iter()
        .map(|filepath| {
            generate_integrity_hash(format!("{}/{}", pkg_path, filepath))
                .map(|hash| JavaScriptFile { hash })
                .map_err(|err| {
                    admin_error!(
                        ?err,
                        "Failed to generate integrity hash for {} - cancelling startup!",
                        filepath
                    );
                })
        })
        .collect()
}
pub async fn create_https_server(
config: Configuration,
jws_signer: JwsHs256Signer,
status_ref: &'static StatusActor,
qe_w_ref: &'static QueryServerWriteV1,
qe_r_ref: &'static QueryServerReadV1,
server_message_tx: broadcast::Sender<CoreAction>,
maybe_tls_acceptor: Option<SslAcceptor>,
tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
) -> Result<task::JoinHandle<()>, ()> {
let rx = server_message_tx.subscribe();
let all_js_files = get_js_files(config.role)?;
let js_directives = all_js_files
.into_iter()
.map(|f| f.hash)
.collect::<Vec<String>>();
let js_checksums: String = js_directives
.iter()
.fold(String::new(), |mut output, value| {
let _ = write!(output, " 'sha384-{}'", value);
output
});
let csp_header = format!(
concat!(
"default-src 'self'; ",
"base-uri 'self' https:; ",
"form-action 'self' https:;",
"frame-ancestors 'none'; ",
"img-src 'self' data:; ",
"worker-src 'none'; ",
"script-src 'self' 'unsafe-eval'{};",
),
js_checksums
);
let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
error!(?err, "Unable to generate content security policy");
})?;
let trust_x_forward_for = config.trust_x_forward_for;
let origin = Url::parse(&config.origin)
.map_err(|err| {
error!(?err, "Unable to parse origin URL - refusing to start. You must correct the value for origin. {:?}", config.origin);
})?;
let state = ServerState {
status_ref,
qe_w_ref,
qe_r_ref,
jws_signer,
trust_x_forward_for,
csp_header,
origin,
domain: config.domain.clone(),
secure_cookies: config.integration_test_config.is_none(),
};
let static_routes = match config.role {
ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
Router::new()
.route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
.route("/ui/images/domain", get(v1_domain::image_get))
.route("/manifest.webmanifest", get(manifest::manifest)) .layer(middleware::compression::new())
.layer(from_fn(middleware::caching::cache_me_short))
.route("/", get(|| async { Redirect::to("/ui") }))
.nest("/ui", views::view_router())
}
ServerRole::WriteReplicaNoUI => Router::new(),
};
let app = Router::new()
.merge(oauth2::route_setup(state.clone()))
.merge(v1_scim::route_setup())
.merge(v1::route_setup(state.clone()))
.route("/robots.txt", get(generic::robots_txt));
let app = match config.role {
ServerRole::WriteReplicaNoUI => app,
ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
let pkg_path = PathBuf::from(env!("KANIDM_HTMX_UI_PKG_PATH"));
if !pkg_path.exists() {
eprintln!(
"Couldn't find htmx UI package path: ({}), quitting.",
env!("KANIDM_HTMX_UI_PKG_PATH")
);
std::process::exit(1);
}
let pkg_router = Router::new()
.nest_service("/pkg", ServeDir::new(pkg_path))
.layer(from_fn(middleware::caching::cache_me_short));
app.merge(pkg_router)
}
};
let trace_layer = TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpanKanidmd::new())
.on_response(trace::DefaultOnResponseKanidmd::new());
let app = app
.merge(static_routes)
.layer(from_fn_with_state(
state.clone(),
middleware::security_headers::security_headers_layer,
))
.layer(from_fn(middleware::version_middleware))
.layer(from_fn(
middleware::hsts_header::strict_transport_security_layer,
));
#[cfg(any(test, debug_assertions))]
let app = app.layer(from_fn(middleware::are_we_json_yet));
let app = app
.route("/status", get(generic::status))
.layer(from_fn(middleware::kopid_middleware))
.merge(apidocs::router())
.layer(trace_layer)
.with_state(state)
.into_make_service_with_connect_info::<ClientConnInfo>();
let addr = SocketAddr::from_str(&config.address).map_err(|err| {
error!(
"Failed to parse address ({:?}) from config: {:?}",
config.address, err
);
})?;
info!("Starting the web server...");
match maybe_tls_acceptor {
Some(tls_acceptor) => {
let listener = match TcpListener::bind(addr).await {
Ok(l) => l,
Err(err) => {
error!(?err, "Failed to bind tcp listener");
return Err(());
}
};
Ok(task::spawn(server_loop(
tls_acceptor,
listener,
app,
rx,
server_message_tx,
tls_acceptor_reload_rx,
)))
}
None => Ok(task::spawn(server_loop_plaintext(addr, app, rx))),
}
}
/// Accept TCP connections and hand each one to `handle_conn` on its own task.
///
/// The loop runs until one of two things happens:
/// - a `CoreAction::Shutdown` is received on `rx`;
/// - `accept` fails, in which case a shutdown is broadcast on `server_message_tx`.
///
/// A new `SslAcceptor` received on `tls_acceptor_reload_rx` replaces the
/// current one. Later connections use it; connections already spawned keep
/// their own clone.
async fn server_loop(
    mut tls_acceptor: SslAcceptor,
    listener: TcpListener,
    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
    mut rx: broadcast::Receiver<CoreAction>,
    server_message_tx: broadcast::Sender<CoreAction>,
    mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
) {
    loop {
        tokio::select! {
            Ok(action) = rx.recv() => {
                match action {
                    CoreAction::Shutdown => break,
                }
            }
            accepted = listener.accept() => {
                let (stream, addr) = match accepted {
                    Ok(connection) => connection,
                    Err(err) => {
                        error!("Web server exited with {:?}", err);
                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
                            error!("Web server failed to send shutdown message! {:?}", err)
                        };
                        break;
                    }
                };
                task::spawn(handle_conn(tls_acceptor.clone(), stream, app.clone(), addr));
            }
            Some(reloaded) = tls_acceptor_reload_rx.recv() => {
                tls_acceptor = reloaded;
                info!("Reloaded http tls acceptor");
            }
        }
    }

    info!("Stopped {}", super::TaskName::HttpsServer);
}
/// Serve `app` over plaintext HTTP on `addr` until shutdown or server exit.
///
/// The loop ends in either of these cases:
/// - a `CoreAction::Shutdown` arrives on `rx`; dropping the serve future then
///   stops the server;
/// - the serve future itself completes, for example because binding `addr`
///   failed. An error result is logged.
async fn server_loop_plaintext(
    addr: SocketAddr,
    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
    mut rx: broadcast::Receiver<CoreAction>,
) {
    let listener = axum_server::bind(addr).serve(app);

    pin_mut!(listener);

    loop {
        tokio::select! {
            Ok(action) = rx.recv() => {
                match action {
                    CoreAction::Shutdown => break,
                }
            }
            result = &mut listener => {
                // The serve future has finished. Polling it again on the next
                // iteration would panic, so report the outcome and stop.
                if let Err(err) = result {
                    error!(?err, "Plaintext web server exited with an error");
                }
                break;
            }
        }
    }

    info!("Stopped {}", super::TaskName::HttpsServer);
}
pub(crate) async fn handle_conn(
acceptor: SslAcceptor,
stream: TcpStream,
mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
addr: SocketAddr,
) -> Result<(), std::io::Error> {
let ssl = Ssl::new(acceptor.context()).map_err(|e| {
error!("Failed to create TLS context: {:?}", e);
std::io::Error::from(ErrorKind::ConnectionAborted)
})?;
let mut tls_stream = SslStream::new(ssl, stream).map_err(|err| {
error!(?err, "Failed to create TLS stream");
std::io::Error::from(ErrorKind::ConnectionAborted)
})?;
match SslStream::accept(Pin::new(&mut tls_stream)).await {
Ok(_) => {
let client_cert = if let Some(peer_cert) = tls_stream.ssl().peer_certificate() {
let cert_der = peer_cert.to_der().map_err(|ossl_err| {
error!(?ossl_err, "unable to process x509 certificate as DER");
std::io::Error::from(ErrorKind::ConnectionAborted)
})?;
let certificate = Certificate::from_der(&cert_der).map_err(|ossl_err| {
error!(?ossl_err, "unable to process DER certificate to x509");
std::io::Error::from(ErrorKind::ConnectionAborted)
})?;
let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
error!("subject public key bitstring is not octet aligned");
std::io::Error::from(ErrorKind::ConnectionAborted)
})?;
Some(ClientCertInfo {
public_key_s256,
certificate,
})
} else {
None
};
let client_conn_info = ClientConnInfo { addr, client_cert };
debug!(?client_conn_info);
let svc = axum_server::service::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
&mut app,
client_conn_info,
);
let svc = svc.await.map_err(|e| {
error!("Failed to build HTTP response: {:?}", e);
std::io::Error::from(ErrorKind::Other)
})?;
let stream = TokioIo::new(tls_stream);
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
svc.clone().call(request)
});
hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await
.map_err(|e| {
debug!("Failed to complete connection: {:?}", e);
std::io::Error::from(ErrorKind::ConnectionAborted)
})
}
Err(error) => {
trace!("Failed to handle connection: {:?}", error);
Ok(())
}
}
}