kanidmd_core/
admin.rs

1use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
2use crate::repl::ReplCtrl;
3use crate::CoreAction;
4use bytes::{BufMut, BytesMut};
5use futures::{SinkExt, StreamExt};
6use kanidm_lib_crypto::serialise::x509b64;
7use kanidm_utils_users::get_current_uid;
8use serde::{Deserialize, Serialize};
9use std::error::Error;
10use std::io;
11use std::path::Path;
12use tokio::net::{UnixListener, UnixStream};
13use tokio::sync::broadcast;
14use tokio::sync::mpsc;
15use tokio::sync::oneshot;
16use tokio_util::codec::{Decoder, Encoder, Framed};
17use tracing::{span, Instrument, Level};
18use uuid::Uuid;
19
20pub use kanidm_proto::internal::{
21    DomainInfo as ProtoDomainInfo, DomainUpgradeCheckReport as ProtoDomainUpgradeCheckReport,
22    DomainUpgradeCheckStatus as ProtoDomainUpgradeCheckStatus,
23};
24
25#[derive(Serialize, Deserialize, Debug)]
26pub enum AdminTaskRequest {
27    RecoverAccount { name: String },
28    DisableAccount { name: String },
29    ShowReplicationCertificate,
30    RenewReplicationCertificate,
31    RefreshReplicationConsumer,
32    DomainShow,
33    DomainUpgradeCheck,
34    DomainRaise,
35    DomainRemigrate { level: Option<u32> },
36}
37
38#[derive(Serialize, Deserialize)]
39pub enum AdminTaskResponse {
40    RecoverAccount {
41        password: String,
42    },
43    ShowReplicationCertificate {
44        cert: String,
45    },
46    DomainUpgradeCheck {
47        report: ProtoDomainUpgradeCheckReport,
48    },
49    DomainRaise {
50        level: u32,
51    },
52    DomainShow {
53        domain_info: ProtoDomainInfo,
54    },
55    Success,
56    Error,
57}
58
59impl std::fmt::Debug for AdminTaskResponse {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            // the intent here is that we aren't sharing secret material in logs
63            AdminTaskResponse::RecoverAccount { .. } => write!(f, "RecoverAccount {{ .. }}"),
64            // the intent here is that we aren't sharing secret material in logs
65            AdminTaskResponse::ShowReplicationCertificate { .. } => {
66                write!(f, "ShowReplicationCertificate {{ .. }}",)
67            }
68            AdminTaskResponse::DomainUpgradeCheck { report } => {
69                write!(f, "DomainUpgradeCheck {{ report: {:?} }}", report)
70            }
71            AdminTaskResponse::DomainRaise { level } => {
72                write!(f, "DomainRaise {{ level: {} }}", level)
73            }
74            AdminTaskResponse::DomainShow { domain_info } => {
75                write!(f, "DomainShow {{ domain_info: {:?} }}", domain_info)
76            }
77            AdminTaskResponse::Success => write!(f, "Success"),
78            AdminTaskResponse::Error => write!(f, "Error"),
79        }
80    }
81}
82
83#[derive(Default)]
84pub struct ClientCodec;
85
86impl Decoder for ClientCodec {
87    type Error = io::Error;
88    type Item = AdminTaskResponse;
89
90    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
91        trace!("Attempting to decode request ...");
92        match serde_json::from_slice::<AdminTaskResponse>(src) {
93            Ok(msg) => {
94                // Clear the buffer for the next message.
95                src.clear();
96                Ok(Some(msg))
97            }
98            _ => Ok(None),
99        }
100    }
101}
102
103impl Encoder<AdminTaskRequest> for ClientCodec {
104    type Error = io::Error;
105
106    fn encode(&mut self, msg: AdminTaskRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
107        trace!("Attempting to send response -> {:?} ...", msg);
108        let data = serde_json::to_vec(&msg).map_err(|e| {
109            error!("socket encoding error -> {:?}", e);
110            io::Error::other("JSON encode error")
111        })?;
112        dst.put(data.as_slice());
113        Ok(())
114    }
115}
116
117#[derive(Default)]
118struct ServerCodec;
119
120impl Decoder for ServerCodec {
121    type Error = io::Error;
122    type Item = AdminTaskRequest;
123
124    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
125        trace!("Attempting to decode request ...");
126        match serde_json::from_slice::<AdminTaskRequest>(src) {
127            Ok(msg) => {
128                // Clear the buffer for the next message.
129                src.clear();
130                Ok(Some(msg))
131            }
132            _ => Ok(None),
133        }
134    }
135}
136
137impl Encoder<AdminTaskResponse> for ServerCodec {
138    type Error = io::Error;
139
140    fn encode(&mut self, msg: AdminTaskResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
141        trace!("Attempting to send response -> {:?} ...", msg);
142        let data = serde_json::to_vec(&msg).map_err(|e| {
143            error!("socket encoding error -> {:?}", e);
144            io::Error::other("JSON encode error")
145        })?;
146        dst.put(data.as_slice());
147        Ok(())
148    }
149}
150
151pub(crate) struct AdminActor;
152
153impl AdminActor {
154    pub async fn create_admin_sock(
155        sock_path: &str,
156        server_rw: &'static QueryServerWriteV1,
157        server_ro: &'static QueryServerReadV1,
158        mut broadcast_rx: broadcast::Receiver<CoreAction>,
159        repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
160    ) -> Result<tokio::task::JoinHandle<()>, ()> {
161        debug!("🧹 Cleaning up sockets from previous invocations");
162        rm_if_exist(sock_path);
163
164        // Setup the unix socket.
165        let listener = match UnixListener::bind(sock_path) {
166            Ok(l) => l,
167            Err(e) => {
168                error!(err = ?e, "Failed to bind UNIX socket {}", sock_path);
169                return Err(());
170            }
171        };
172
173        // what is the uid we are running as?
174        let cuid = get_current_uid();
175
176        let handle = tokio::spawn(async move {
177            loop {
178                tokio::select! {
179                    Ok(action) = broadcast_rx.recv() => {
180                        match action {
181                            CoreAction::Shutdown => break,
182                        }
183                    }
184                    accept_res = listener.accept() => {
185                        match accept_res {
186                            Ok((socket, _addr)) => {
187                                // Assert that the incoming connection is from root or
188                                // our own uid.
189                                // ⚠️  This underpins the security of this socket ⚠️
190                                if let Ok(ucred) = socket.peer_cred() {
191                                    let incoming_uid = ucred.uid();
192                                    if incoming_uid == 0 || incoming_uid == cuid {
193                                        // all good!
194                                        info!(pid = ?ucred.pid(), "Allowing admin socket access");
195                                    } else {
196                                        warn!(%incoming_uid, "unauthorised user");
197                                        continue;
198                                    }
199                                } else {
200                                    error!("unable to determine peer credentials");
201                                    continue;
202                                };
203
204                                // spawn the worker.
205                                let task_repl_ctrl_tx = repl_ctrl_tx.clone();
206                                tokio::spawn(async move {
207                                    if let Err(e) = handle_client(socket, server_rw, server_ro, task_repl_ctrl_tx).await {
208                                        error!(err = ?e, "admin client error");
209                                    }
210                                });
211                            }
212                            Err(e) => {
213                                warn!(err = ?e, "admin socket accept error");
214                            }
215                        }
216                    }
217                }
218            }
219            info!("Stopped {}", super::TaskName::AdminSocket);
220        });
221        Ok(handle)
222    }
223}
224
225fn rm_if_exist(p: &str) {
226    if Path::new(p).exists() {
227        debug!("Removing requested file {:?}", p);
228        let _ = std::fs::remove_file(p).map_err(|e| {
229            error!(
230                "Failure while attempting to attempting to remove {:?} -> {:?}",
231                p, e
232            );
233        });
234    } else {
235        debug!("Path {:?} doesn't exist, not attempting to remove.", p);
236    }
237}
238
239async fn show_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
240    let (tx, rx) = oneshot::channel();
241
242    if ctrl_tx
243        .send(ReplCtrl::GetCertificate { respond: tx })
244        .await
245        .is_err()
246    {
247        error!("replication control channel has shutdown");
248        return AdminTaskResponse::Error;
249    }
250
251    match rx.await {
252        Ok(cert) => x509b64::cert_to_string(&cert)
253            .map(|cert| AdminTaskResponse::ShowReplicationCertificate { cert })
254            .unwrap_or(AdminTaskResponse::Error),
255        Err(_) => {
256            error!("replication control channel did not respond with certificate.");
257            AdminTaskResponse::Error
258        }
259    }
260}
261
262async fn renew_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
263    let (tx, rx) = oneshot::channel();
264
265    if ctrl_tx
266        .send(ReplCtrl::RenewCertificate { respond: tx })
267        .await
268        .is_err()
269    {
270        error!("replication control channel has shutdown");
271        return AdminTaskResponse::Error;
272    }
273
274    match rx.await {
275        Ok(success) => {
276            if success {
277                show_replication_certificate(ctrl_tx).await
278            } else {
279                error!("replication control channel indicated that certificate renewal failed.");
280                AdminTaskResponse::Error
281            }
282        }
283        Err(_) => {
284            error!("replication control channel did not respond with renewal status.");
285            AdminTaskResponse::Error
286        }
287    }
288}
289
290async fn replication_consumer_refresh(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
291    let (tx, rx) = oneshot::channel();
292
293    if ctrl_tx
294        .send(ReplCtrl::RefreshConsumer { respond: tx })
295        .await
296        .is_err()
297    {
298        error!("replication control channel has shutdown");
299        return AdminTaskResponse::Error;
300    }
301
302    match rx.await {
303        Ok(mut refresh_rx) => {
304            if let Some(()) = refresh_rx.recv().await {
305                info!("Replication refresh success");
306                AdminTaskResponse::Success
307            } else {
308                error!("Replication refresh failed. Please inspect the logs.");
309                AdminTaskResponse::Error
310            }
311        }
312        Err(_) => {
313            error!("replication control channel did not respond with refresh status.");
314            AdminTaskResponse::Error
315        }
316    }
317}
318
319async fn handle_client(
320    sock: UnixStream,
321    server_rw: &'static QueryServerWriteV1,
322    server_ro: &'static QueryServerReadV1,
323    mut repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
324) -> Result<(), Box<dyn Error>> {
325    debug!("Accepted admin socket connection");
326
327    let mut reqs = Framed::new(sock, ServerCodec);
328
329    trace!("Waiting for requests ...");
330    while let Some(Ok(req)) = reqs.next().await {
331        // Setup the logging span
332        let eventid = Uuid::new_v4();
333        let nspan = span!(Level::INFO, "handle_admin_client_request", uuid = ?eventid);
334
335        let resp = async {
336            match req {
337                AdminTaskRequest::RecoverAccount { name } => {
338                    match server_rw.handle_admin_recover_account(name, eventid).await {
339                        Ok(password) => AdminTaskResponse::RecoverAccount { password },
340                        Err(e) => {
341                            error!(err = ?e, "error during recover-account");
342                            AdminTaskResponse::Error
343                        }
344                    }
345                }
346                AdminTaskRequest::DisableAccount { name } => {
347                    match server_rw.handle_admin_disable_account(name, eventid).await {
348                        Ok(()) => AdminTaskResponse::Success,
349                        Err(e) => {
350                            error!(err = ?e, "error during disable-account");
351                            AdminTaskResponse::Error
352                        }
353                    }
354                }
355                AdminTaskRequest::ShowReplicationCertificate => match repl_ctrl_tx.as_mut() {
356                    Some(ctrl_tx) => show_replication_certificate(ctrl_tx).await,
357                    None => {
358                        error!("replication not configured, unable to display certificate.");
359                        AdminTaskResponse::Error
360                    }
361                },
362                AdminTaskRequest::RenewReplicationCertificate => match repl_ctrl_tx.as_mut() {
363                    Some(ctrl_tx) => renew_replication_certificate(ctrl_tx).await,
364                    None => {
365                        error!("replication not configured, unable to renew certificate.");
366                        AdminTaskResponse::Error
367                    }
368                },
369                AdminTaskRequest::RefreshReplicationConsumer => match repl_ctrl_tx.as_mut() {
370                    Some(ctrl_tx) => replication_consumer_refresh(ctrl_tx).await,
371                    None => {
372                        error!("replication not configured, unable to refresh consumer.");
373                        AdminTaskResponse::Error
374                    }
375                },
376
377                AdminTaskRequest::DomainShow => match server_ro.handle_domain_show(eventid).await {
378                    Ok(domain_info) => AdminTaskResponse::DomainShow { domain_info },
379                    Err(e) => {
380                        error!(err = ?e, "error during domain show");
381                        AdminTaskResponse::Error
382                    }
383                },
384                AdminTaskRequest::DomainUpgradeCheck => {
385                    match server_ro.handle_domain_upgrade_check(eventid).await {
386                        Ok(report) => AdminTaskResponse::DomainUpgradeCheck { report },
387                        Err(e) => {
388                            error!(err = ?e, "error during domain upgrade checkr");
389                            AdminTaskResponse::Error
390                        }
391                    }
392                }
393                AdminTaskRequest::DomainRaise => match server_rw.handle_domain_raise(eventid).await
394                {
395                    Ok(level) => AdminTaskResponse::DomainRaise { level },
396                    Err(e) => {
397                        error!(err = ?e, "error during domain raise");
398                        AdminTaskResponse::Error
399                    }
400                },
401                AdminTaskRequest::DomainRemigrate { level } => {
402                    match server_rw.handle_domain_remigrate(level, eventid).await {
403                        Ok(()) => AdminTaskResponse::Success,
404                        Err(e) => {
405                            error!(err = ?e, "error during domain remigrate");
406                            AdminTaskResponse::Error
407                        }
408                    }
409                }
410            }
411        }
412        .instrument(nspan)
413        .await;
414
415        reqs.send(resp).await?;
416        reqs.flush().await?;
417    }
418
419    debug!("Disconnecting client ...");
420    Ok(())
421}