Skip to main content

kanidmd_core/
admin.rs

1use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
2use crate::repl::ReplCtrl;
3use crate::CoreAction;
4use bytes::{BufMut, BytesMut};
5use crypto_glue::x509::x509b64;
6use futures::{SinkExt, StreamExt};
7pub use kanidm_proto::internal::{
8    DomainInfo as ProtoDomainInfo, DomainUpgradeCheckReport as ProtoDomainUpgradeCheckReport,
9    DomainUpgradeCheckStatus as ProtoDomainUpgradeCheckStatus,
10};
11use kanidm_utils_users::get_current_uid;
12use serde::{Deserialize, Serialize};
13use std::error::Error;
14use std::io;
15use std::time::Duration;
16use tokio::net::{UnixListener, UnixStream};
17use tokio::sync::broadcast;
18use tokio::sync::mpsc;
19use tokio::sync::oneshot;
20use tokio::time::timeout;
21use tokio_util::codec::{Decoder, Encoder, Framed};
22use tracing::{span, Instrument, Level};
23use uuid::Uuid;
24
/// Upper bound on every wait for the replication controller to answer an
/// admin-socket request, so a wedged controller cannot hang an admin client
/// forever. The consumer-refresh path applies it twice: once for the initial
/// reply and once more for refresh completion.
const REPL_CTRL_TIMEOUT: Duration = Duration::from_secs(15);
27
28#[derive(Serialize, Deserialize, Debug)]
29pub enum AdminTaskRequest {
30    RecoverAccount { name: String },
31    DisableAccount { name: String },
32    ShowReplicationCertificate,
33    ShowReplicationCertificateMetadata,
34    RenewReplicationCertificate,
35    RefreshReplicationConsumer,
36    DomainShow,
37    DomainUpgradeCheck,
38    DomainRaise,
39    DomainRemigrate { level: Option<u32> },
40    Reload,
41}
42
43#[derive(Serialize, Deserialize)]
44pub enum AdminTaskResponse {
45    RecoverAccount {
46        password: String,
47    },
48    ShowReplicationCertificate {
49        cert: String,
50    },
51    ShowReplicationCertificateMetadata {
52        not_before: String,
53        not_after: String,
54        subject: String,
55        expired: bool,
56    },
57    DomainUpgradeCheck {
58        report: ProtoDomainUpgradeCheckReport,
59    },
60    DomainRaise {
61        level: u32,
62    },
63    DomainShow {
64        domain_info: ProtoDomainInfo,
65    },
66    Success,
67    Error,
68}
69
70impl std::fmt::Debug for AdminTaskResponse {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        match self {
73            // the intent here is that we aren't sharing secret material in logs
74            AdminTaskResponse::RecoverAccount { .. } => write!(f, "RecoverAccount {{ .. }}"),
75            // the intent here is that we aren't sharing secret material in logs
76            AdminTaskResponse::ShowReplicationCertificate { .. } => {
77                write!(f, "ShowReplicationCertificate {{ .. }}",)
78            }
79            AdminTaskResponse::ShowReplicationCertificateMetadata {
80                not_before,
81                not_after,
82                subject,
83                expired,
84            } => {
85                write!(f, "ShowReplicationCertificateMetadata {{ not_before: {:?}, not_after: {:?}, subject: {:?}, expired: {} }}", not_before, not_after, subject, expired)
86            }
87            AdminTaskResponse::DomainUpgradeCheck { report } => {
88                write!(f, "DomainUpgradeCheck {{ report: {:?} }}", report)
89            }
90            AdminTaskResponse::DomainRaise { level } => {
91                write!(f, "DomainRaise {{ level: {} }}", level)
92            }
93            AdminTaskResponse::DomainShow { domain_info } => {
94                write!(f, "DomainShow {{ domain_info: {:?} }}", domain_info)
95            }
96            AdminTaskResponse::Success => write!(f, "Success"),
97            AdminTaskResponse::Error => write!(f, "Error"),
98        }
99    }
100}
101
102#[derive(Default)]
103pub struct ClientCodec;
104
105impl Decoder for ClientCodec {
106    type Error = io::Error;
107    type Item = AdminTaskResponse;
108
109    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
110        trace!("Attempting to decode request ...");
111        match serde_json::from_slice::<AdminTaskResponse>(src) {
112            Ok(msg) => {
113                // Clear the buffer for the next message.
114                src.clear();
115                Ok(Some(msg))
116            }
117            _ => Ok(None),
118        }
119    }
120}
121
122impl Encoder<AdminTaskRequest> for ClientCodec {
123    type Error = io::Error;
124
125    fn encode(&mut self, msg: AdminTaskRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
126        trace!("Attempting to send response -> {:?} ...", msg);
127        let data = serde_json::to_vec(&msg).map_err(|e| {
128            error!("socket encoding error -> {:?}", e);
129            io::Error::other("JSON encode error")
130        })?;
131        dst.put(data.as_slice());
132        Ok(())
133    }
134}
135
136#[derive(Default)]
137struct ServerCodec;
138
139impl Decoder for ServerCodec {
140    type Error = io::Error;
141    type Item = AdminTaskRequest;
142
143    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
144        trace!("Attempting to decode request ...");
145        match serde_json::from_slice::<AdminTaskRequest>(src) {
146            Ok(msg) => {
147                // Clear the buffer for the next message.
148                src.clear();
149                Ok(Some(msg))
150            }
151            _ => Ok(None),
152        }
153    }
154}
155
156impl Encoder<AdminTaskResponse> for ServerCodec {
157    type Error = io::Error;
158
159    fn encode(&mut self, msg: AdminTaskResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
160        trace!("Attempting to send response -> {:?} ...", msg);
161        let data = serde_json::to_vec(&msg).map_err(|e| {
162            error!("socket encoding error -> {:?}", e);
163            io::Error::other("JSON encode error")
164        })?;
165        dst.put(data.as_slice());
166        Ok(())
167    }
168}
169
170pub(crate) struct AdminActor;
171
172impl AdminActor {
173    pub async fn create_admin_sock(
174        sock_path: &str,
175        server_rw: &'static QueryServerWriteV1,
176        server_ro: &'static QueryServerReadV1,
177        broadcast_tx: broadcast::Sender<CoreAction>,
178        repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
179    ) -> Result<tokio::task::JoinHandle<()>, ()> {
180        debug!("🧹 Cleaning up sockets from previous invocations");
181        rm_if_exist(sock_path);
182
183        // Setup the unix socket.
184        let listener = match UnixListener::bind(sock_path) {
185            Ok(l) => l,
186            Err(e) => {
187                error!(err = ?e, "Failed to bind UNIX socket {}", sock_path);
188                return Err(());
189            }
190        };
191
192        let mut broadcast_rx = broadcast_tx.subscribe();
193
194        // what is the uid we are running as?
195        let cuid = get_current_uid();
196
197        let handle = tokio::spawn(async move {
198            loop {
199                tokio::select! {
200                    Ok(action) = broadcast_rx.recv() => {
201                        match action {
202                            CoreAction::Shutdown => break,
203                            CoreAction::Reload => {},
204                        }
205                    }
206                    accept_res = listener.accept() => {
207                        match accept_res {
208                            Ok((socket, _addr)) => {
209                                // Assert that the incoming connection is from root or
210                                // our own uid.
211                                // ⚠️  This underpins the security of this socket ⚠️
212                                if let Ok(ucred) = socket.peer_cred() {
213                                    let incoming_uid = ucred.uid();
214                                    if incoming_uid == 0 || incoming_uid == cuid {
215                                        // all good!
216                                        info!(pid = ?ucred.pid(), "Allowing admin socket access");
217                                    } else {
218                                        warn!(%incoming_uid, "unauthorised user");
219                                        continue;
220                                    }
221                                } else {
222                                    error!("unable to determine peer credentials");
223                                    continue;
224                                };
225
226                                // spawn the worker.
227                                let task_repl_ctrl_tx = repl_ctrl_tx.clone();
228                                let broadcast_tx_ = broadcast_tx.clone();
229                                tokio::spawn(async move {
230                                    if let Err(e) = handle_client(socket, server_rw, server_ro, task_repl_ctrl_tx, broadcast_tx_).await {
231                                        error!(err = ?e, "admin client error");
232                                    }
233                                });
234                            }
235                            Err(e) => {
236                                warn!(err = ?e, "admin socket accept error");
237                            }
238                        }
239                    }
240                }
241            }
242            info!("Stopped {}", super::TaskName::AdminSocket);
243        });
244        Ok(handle)
245    }
246}
247
248fn rm_if_exist(p: &str) {
249    debug!("Attempting to remove requested file {}", p);
250    let _ = std::fs::remove_file(p).map_err(|e| match e.kind() {
251        std::io::ErrorKind::NotFound => {
252            debug!("{} not present, no need to remove.", p);
253        }
254        _ => {
255            error!(
256                "Failure while attempting to attempting to remove {} -> {}",
257                p,
258                e.to_string()
259            );
260        }
261    });
262}
263
264async fn show_replication_certificate_metadata(
265    ctrl_tx: &mut mpsc::Sender<ReplCtrl>,
266) -> AdminTaskResponse {
267    let (tx, rx) = oneshot::channel();
268
269    if ctrl_tx
270        .send(ReplCtrl::GetCertificate { respond: tx })
271        .await
272        .is_err()
273    {
274        error!("replication control channel has shutdown");
275        AdminTaskResponse::Error
276    } else {
277        match timeout(REPL_CTRL_TIMEOUT, rx).await {
278            Ok(Ok(cert)) => {
279                let cert_not_after = cert.tbs_certificate.validity.not_after;
280                let cert_not_before = cert.tbs_certificate.validity.not_before;
281                let subject = cert.tbs_certificate.subject.to_string();
282
283                let expired = cert_not_after.to_system_time() < std::time::SystemTime::now();
284                AdminTaskResponse::ShowReplicationCertificateMetadata {
285                    expired,
286                    not_before: cert_not_before.to_string(),
287                    not_after: cert_not_after.to_string(),
288                    subject,
289                }
290            }
291            Ok(Err(_)) => {
292                error!("replication control channel did not respond with certificate.");
293                AdminTaskResponse::Error
294            }
295            Err(_) => {
296                error!("timed out waiting for replication certificate metadata.");
297                AdminTaskResponse::Error
298            }
299        }
300    }
301}
302
303async fn show_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
304    let (tx, rx) = oneshot::channel();
305
306    if ctrl_tx
307        .send(ReplCtrl::GetCertificate { respond: tx })
308        .await
309        .is_err()
310    {
311        error!("replication control channel has shutdown");
312        return AdminTaskResponse::Error;
313    }
314
315    match timeout(REPL_CTRL_TIMEOUT, rx).await {
316        Ok(Ok(cert)) => x509b64::cert_to_string(&cert)
317            .map(|cert| AdminTaskResponse::ShowReplicationCertificate { cert })
318            .unwrap_or(AdminTaskResponse::Error),
319        Ok(Err(_)) => {
320            error!("replication control channel did not respond with certificate.");
321            AdminTaskResponse::Error
322        }
323        Err(_) => {
324            error!("timed out waiting for replication certificate response.");
325            AdminTaskResponse::Error
326        }
327    }
328}
329
330async fn renew_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
331    let (tx, rx) = oneshot::channel();
332
333    if ctrl_tx
334        .send(ReplCtrl::RenewCertificate { respond: tx })
335        .await
336        .is_err()
337    {
338        error!("replication control channel has shutdown");
339        return AdminTaskResponse::Error;
340    }
341
342    match timeout(REPL_CTRL_TIMEOUT, rx).await {
343        Ok(Ok(success)) => {
344            if success {
345                show_replication_certificate(ctrl_tx).await
346            } else {
347                error!("replication control channel indicated that certificate renewal failed.");
348                AdminTaskResponse::Error
349            }
350        }
351        Ok(Err(_)) => {
352            error!("replication control channel did not respond with renewal status.");
353            AdminTaskResponse::Error
354        }
355        Err(_) => {
356            error!("timed out waiting for replication renewal status.");
357            AdminTaskResponse::Error
358        }
359    }
360}
361
362async fn replication_consumer_refresh(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
363    let (tx, rx) = oneshot::channel();
364
365    if ctrl_tx
366        .send(ReplCtrl::RefreshConsumer { respond: tx })
367        .await
368        .is_err()
369    {
370        error!("replication control channel has shutdown");
371        return AdminTaskResponse::Error;
372    }
373
374    match timeout(REPL_CTRL_TIMEOUT, rx).await {
375        Ok(Ok(mut refresh_rx)) => match timeout(REPL_CTRL_TIMEOUT, refresh_rx.recv()).await {
376            Ok(Some(())) => {
377                info!("Replication refresh success");
378                AdminTaskResponse::Success
379            }
380            Ok(None) => {
381                error!("Replication refresh failed. Please inspect the logs.");
382                AdminTaskResponse::Error
383            }
384            Err(_) => {
385                error!("timed out waiting for replication refresh completion.");
386                AdminTaskResponse::Error
387            }
388        },
389        Ok(Err(_)) => {
390            error!("replication control channel did not respond with refresh status.");
391            AdminTaskResponse::Error
392        }
393        Err(_) => {
394            error!("timed out waiting for replication refresh status.");
395            AdminTaskResponse::Error
396        }
397    }
398}
399
400async fn handle_client(
401    sock: UnixStream,
402    server_rw: &'static QueryServerWriteV1,
403    server_ro: &'static QueryServerReadV1,
404    mut repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
405    broadcast_tx: broadcast::Sender<CoreAction>,
406) -> Result<(), Box<dyn Error>> {
407    debug!("Accepted admin socket connection");
408
409    let mut reqs = Framed::new(sock, ServerCodec);
410
411    trace!("Waiting for requests ...");
412    while let Some(Ok(req)) = reqs.next().await {
413        // Setup the logging span
414        let eventid = Uuid::new_v4();
415        let nspan = span!(Level::INFO, "handle_admin_client_request", uuid = ?eventid);
416
417        let resp = async {
418            match req {
419                AdminTaskRequest::RecoverAccount { name } => {
420                    match server_rw.handle_admin_recover_account(name, eventid).await {
421                        Ok(password) => AdminTaskResponse::RecoverAccount { password },
422                        Err(e) => {
423                            error!(err = ?e, "error during recover-account");
424                            AdminTaskResponse::Error
425                        }
426                    }
427                }
428                AdminTaskRequest::DisableAccount { name } => {
429                    match server_rw.handle_admin_disable_account(name, eventid).await {
430                        Ok(()) => AdminTaskResponse::Success,
431                        Err(e) => {
432                            error!(err = ?e, "error during disable-account");
433                            AdminTaskResponse::Error
434                        }
435                    }
436                }
437                AdminTaskRequest::ShowReplicationCertificate => match repl_ctrl_tx.as_mut() {
438                    Some(ctrl_tx) => show_replication_certificate(ctrl_tx).await,
439                    None => {
440                        error!("replication not configured, unable to display certificate.");
441                        AdminTaskResponse::Error
442                    }
443                },
444                AdminTaskRequest::ShowReplicationCertificateMetadata => match repl_ctrl_tx.as_mut() {
445                    Some(ctrl_tx) => {
446                        show_replication_certificate_metadata(ctrl_tx).await
447                    }
448                    None => {
449                        error!("replication not configured, unable to display certificate metadata.");
450                        AdminTaskResponse::Error
451                    }
452                },
453                AdminTaskRequest::RenewReplicationCertificate => match repl_ctrl_tx.as_mut() {
454                    Some(ctrl_tx) => renew_replication_certificate(ctrl_tx).await,
455                    None => {
456                        error!("replication not configured, unable to renew certificate.");
457                        AdminTaskResponse::Error
458                    }
459                },
460                AdminTaskRequest::RefreshReplicationConsumer => match repl_ctrl_tx.as_mut() {
461                    Some(ctrl_tx) => replication_consumer_refresh(ctrl_tx).await,
462                    None => {
463                        error!("replication not configured, unable to refresh consumer.");
464                        AdminTaskResponse::Error
465                    }
466                },
467
468                AdminTaskRequest::DomainShow => match server_ro.handle_domain_show(eventid).await {
469                    Ok(domain_info) => AdminTaskResponse::DomainShow { domain_info },
470                    Err(e) => {
471                        error!(err = ?e, "error during domain show");
472                        AdminTaskResponse::Error
473                    }
474                },
475                AdminTaskRequest::DomainUpgradeCheck => {
476                    match server_ro.handle_domain_upgrade_check(eventid).await {
477                        Ok(report) => AdminTaskResponse::DomainUpgradeCheck { report },
478                        Err(e) => {
479                            error!(err = ?e, "error during domain upgrade checkr");
480                            AdminTaskResponse::Error
481                        }
482                    }
483                }
484                AdminTaskRequest::DomainRaise => match server_rw.handle_domain_raise(eventid).await
485                {
486                    Ok(level) => AdminTaskResponse::DomainRaise { level },
487                    Err(e) => {
488                        error!(err = ?e, "error during domain raise");
489                        AdminTaskResponse::Error
490                    }
491                },
492                AdminTaskRequest::DomainRemigrate { level } => {
493                    match server_rw.handle_domain_remigrate(level, eventid).await {
494                        Ok(()) => AdminTaskResponse::Success,
495                        Err(e) => {
496                            error!(err = ?e, "error during domain remigrate");
497                            AdminTaskResponse::Error
498                        }
499                    }
500                }
501                AdminTaskRequest::Reload => match broadcast_tx.send(CoreAction::Reload) {
502                    Ok(_) => AdminTaskResponse::Success,
503                    Err(e) => {
504                        error!(err = ?e, "error during server reload");
505                        AdminTaskResponse::Error
506                    }
507                },
508            }
509        }
510        .instrument(nspan)
511        .await;
512
513        reqs.send(resp).await?;
514        reqs.flush().await?;
515    }
516
517    debug!("Disconnecting client ...");
518    Ok(())
519}