1use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
2use crate::repl::ReplCtrl;
3use crate::CoreAction;
4use bytes::{BufMut, BytesMut};
5use futures::{SinkExt, StreamExt};
6use kanidm_lib_crypto::serialise::x509b64;
7use kanidm_utils_users::get_current_uid;
8use serde::{Deserialize, Serialize};
9use std::error::Error;
10use std::io;
11use tokio::net::{UnixListener, UnixStream};
12use tokio::sync::broadcast;
13use tokio::sync::mpsc;
14use tokio::sync::oneshot;
15use tokio_util::codec::{Decoder, Encoder, Framed};
16use tracing::{span, Instrument, Level};
17use uuid::Uuid;
18
19pub use kanidm_proto::internal::{
20 DomainInfo as ProtoDomainInfo, DomainUpgradeCheckReport as ProtoDomainUpgradeCheckReport,
21 DomainUpgradeCheckStatus as ProtoDomainUpgradeCheckStatus,
22};
23
24#[derive(Serialize, Deserialize, Debug)]
25pub enum AdminTaskRequest {
26 RecoverAccount { name: String },
27 DisableAccount { name: String },
28 ShowReplicationCertificate,
29 RenewReplicationCertificate,
30 RefreshReplicationConsumer,
31 DomainShow,
32 DomainUpgradeCheck,
33 DomainRaise,
34 DomainRemigrate { level: Option<u32> },
35 Reload,
36}
37
38#[derive(Serialize, Deserialize)]
39pub enum AdminTaskResponse {
40 RecoverAccount {
41 password: String,
42 },
43 ShowReplicationCertificate {
44 cert: String,
45 },
46 DomainUpgradeCheck {
47 report: ProtoDomainUpgradeCheckReport,
48 },
49 DomainRaise {
50 level: u32,
51 },
52 DomainShow {
53 domain_info: ProtoDomainInfo,
54 },
55 Success,
56 Error,
57}
58
59impl std::fmt::Debug for AdminTaskResponse {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 AdminTaskResponse::RecoverAccount { .. } => write!(f, "RecoverAccount {{ .. }}"),
64 AdminTaskResponse::ShowReplicationCertificate { .. } => {
66 write!(f, "ShowReplicationCertificate {{ .. }}",)
67 }
68 AdminTaskResponse::DomainUpgradeCheck { report } => {
69 write!(f, "DomainUpgradeCheck {{ report: {:?} }}", report)
70 }
71 AdminTaskResponse::DomainRaise { level } => {
72 write!(f, "DomainRaise {{ level: {} }}", level)
73 }
74 AdminTaskResponse::DomainShow { domain_info } => {
75 write!(f, "DomainShow {{ domain_info: {:?} }}", domain_info)
76 }
77 AdminTaskResponse::Success => write!(f, "Success"),
78 AdminTaskResponse::Error => write!(f, "Error"),
79 }
80 }
81}
82
83#[derive(Default)]
84pub struct ClientCodec;
85
86impl Decoder for ClientCodec {
87 type Error = io::Error;
88 type Item = AdminTaskResponse;
89
90 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
91 trace!("Attempting to decode request ...");
92 match serde_json::from_slice::<AdminTaskResponse>(src) {
93 Ok(msg) => {
94 src.clear();
96 Ok(Some(msg))
97 }
98 _ => Ok(None),
99 }
100 }
101}
102
103impl Encoder<AdminTaskRequest> for ClientCodec {
104 type Error = io::Error;
105
106 fn encode(&mut self, msg: AdminTaskRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
107 trace!("Attempting to send response -> {:?} ...", msg);
108 let data = serde_json::to_vec(&msg).map_err(|e| {
109 error!("socket encoding error -> {:?}", e);
110 io::Error::other("JSON encode error")
111 })?;
112 dst.put(data.as_slice());
113 Ok(())
114 }
115}
116
117#[derive(Default)]
118struct ServerCodec;
119
120impl Decoder for ServerCodec {
121 type Error = io::Error;
122 type Item = AdminTaskRequest;
123
124 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
125 trace!("Attempting to decode request ...");
126 match serde_json::from_slice::<AdminTaskRequest>(src) {
127 Ok(msg) => {
128 src.clear();
130 Ok(Some(msg))
131 }
132 _ => Ok(None),
133 }
134 }
135}
136
137impl Encoder<AdminTaskResponse> for ServerCodec {
138 type Error = io::Error;
139
140 fn encode(&mut self, msg: AdminTaskResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
141 trace!("Attempting to send response -> {:?} ...", msg);
142 let data = serde_json::to_vec(&msg).map_err(|e| {
143 error!("socket encoding error -> {:?}", e);
144 io::Error::other("JSON encode error")
145 })?;
146 dst.put(data.as_slice());
147 Ok(())
148 }
149}
150
151pub(crate) struct AdminActor;
152
153impl AdminActor {
154 pub async fn create_admin_sock(
155 sock_path: &str,
156 server_rw: &'static QueryServerWriteV1,
157 server_ro: &'static QueryServerReadV1,
158 broadcast_tx: broadcast::Sender<CoreAction>,
159 repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
160 ) -> Result<tokio::task::JoinHandle<()>, ()> {
161 debug!("🧹 Cleaning up sockets from previous invocations");
162 rm_if_exist(sock_path);
163
164 let listener = match UnixListener::bind(sock_path) {
166 Ok(l) => l,
167 Err(e) => {
168 error!(err = ?e, "Failed to bind UNIX socket {}", sock_path);
169 return Err(());
170 }
171 };
172
173 let mut broadcast_rx = broadcast_tx.subscribe();
174
175 let cuid = get_current_uid();
177
178 let handle = tokio::spawn(async move {
179 loop {
180 tokio::select! {
181 Ok(action) = broadcast_rx.recv() => {
182 match action {
183 CoreAction::Shutdown => break,
184 CoreAction::Reload => {},
185 }
186 }
187 accept_res = listener.accept() => {
188 match accept_res {
189 Ok((socket, _addr)) => {
190 if let Ok(ucred) = socket.peer_cred() {
194 let incoming_uid = ucred.uid();
195 if incoming_uid == 0 || incoming_uid == cuid {
196 info!(pid = ?ucred.pid(), "Allowing admin socket access");
198 } else {
199 warn!(%incoming_uid, "unauthorised user");
200 continue;
201 }
202 } else {
203 error!("unable to determine peer credentials");
204 continue;
205 };
206
207 let task_repl_ctrl_tx = repl_ctrl_tx.clone();
209 let broadcast_tx_ = broadcast_tx.clone();
210 tokio::spawn(async move {
211 if let Err(e) = handle_client(socket, server_rw, server_ro, task_repl_ctrl_tx, broadcast_tx_).await {
212 error!(err = ?e, "admin client error");
213 }
214 });
215 }
216 Err(e) => {
217 warn!(err = ?e, "admin socket accept error");
218 }
219 }
220 }
221 }
222 }
223 info!("Stopped {}", super::TaskName::AdminSocket);
224 });
225 Ok(handle)
226 }
227}
228
229fn rm_if_exist(p: &str) {
230 debug!("Attempting to remove requested file {}", p);
231 let _ = std::fs::remove_file(p).map_err(|e| match e.kind() {
232 std::io::ErrorKind::NotFound => {
233 debug!("{} not present, no need to remove.", p);
234 }
235 _ => {
236 error!(
237 "Failure while attempting to attempting to remove {} -> {}",
238 p,
239 e.to_string()
240 );
241 }
242 });
243}
244
245async fn show_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
246 let (tx, rx) = oneshot::channel();
247
248 if ctrl_tx
249 .send(ReplCtrl::GetCertificate { respond: tx })
250 .await
251 .is_err()
252 {
253 error!("replication control channel has shutdown");
254 return AdminTaskResponse::Error;
255 }
256
257 match rx.await {
258 Ok(cert) => x509b64::cert_to_string(&cert)
259 .map(|cert| AdminTaskResponse::ShowReplicationCertificate { cert })
260 .unwrap_or(AdminTaskResponse::Error),
261 Err(_) => {
262 error!("replication control channel did not respond with certificate.");
263 AdminTaskResponse::Error
264 }
265 }
266}
267
268async fn renew_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
269 let (tx, rx) = oneshot::channel();
270
271 if ctrl_tx
272 .send(ReplCtrl::RenewCertificate { respond: tx })
273 .await
274 .is_err()
275 {
276 error!("replication control channel has shutdown");
277 return AdminTaskResponse::Error;
278 }
279
280 match rx.await {
281 Ok(success) => {
282 if success {
283 show_replication_certificate(ctrl_tx).await
284 } else {
285 error!("replication control channel indicated that certificate renewal failed.");
286 AdminTaskResponse::Error
287 }
288 }
289 Err(_) => {
290 error!("replication control channel did not respond with renewal status.");
291 AdminTaskResponse::Error
292 }
293 }
294}
295
296async fn replication_consumer_refresh(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
297 let (tx, rx) = oneshot::channel();
298
299 if ctrl_tx
300 .send(ReplCtrl::RefreshConsumer { respond: tx })
301 .await
302 .is_err()
303 {
304 error!("replication control channel has shutdown");
305 return AdminTaskResponse::Error;
306 }
307
308 match rx.await {
309 Ok(mut refresh_rx) => {
310 if let Some(()) = refresh_rx.recv().await {
311 info!("Replication refresh success");
312 AdminTaskResponse::Success
313 } else {
314 error!("Replication refresh failed. Please inspect the logs.");
315 AdminTaskResponse::Error
316 }
317 }
318 Err(_) => {
319 error!("replication control channel did not respond with refresh status.");
320 AdminTaskResponse::Error
321 }
322 }
323}
324
325async fn handle_client(
326 sock: UnixStream,
327 server_rw: &'static QueryServerWriteV1,
328 server_ro: &'static QueryServerReadV1,
329 mut repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
330 broadcast_tx: broadcast::Sender<CoreAction>,
331) -> Result<(), Box<dyn Error>> {
332 debug!("Accepted admin socket connection");
333
334 let mut reqs = Framed::new(sock, ServerCodec);
335
336 trace!("Waiting for requests ...");
337 while let Some(Ok(req)) = reqs.next().await {
338 let eventid = Uuid::new_v4();
340 let nspan = span!(Level::INFO, "handle_admin_client_request", uuid = ?eventid);
341
342 let resp = async {
343 match req {
344 AdminTaskRequest::RecoverAccount { name } => {
345 match server_rw.handle_admin_recover_account(name, eventid).await {
346 Ok(password) => AdminTaskResponse::RecoverAccount { password },
347 Err(e) => {
348 error!(err = ?e, "error during recover-account");
349 AdminTaskResponse::Error
350 }
351 }
352 }
353 AdminTaskRequest::DisableAccount { name } => {
354 match server_rw.handle_admin_disable_account(name, eventid).await {
355 Ok(()) => AdminTaskResponse::Success,
356 Err(e) => {
357 error!(err = ?e, "error during disable-account");
358 AdminTaskResponse::Error
359 }
360 }
361 }
362 AdminTaskRequest::ShowReplicationCertificate => match repl_ctrl_tx.as_mut() {
363 Some(ctrl_tx) => show_replication_certificate(ctrl_tx).await,
364 None => {
365 error!("replication not configured, unable to display certificate.");
366 AdminTaskResponse::Error
367 }
368 },
369 AdminTaskRequest::RenewReplicationCertificate => match repl_ctrl_tx.as_mut() {
370 Some(ctrl_tx) => renew_replication_certificate(ctrl_tx).await,
371 None => {
372 error!("replication not configured, unable to renew certificate.");
373 AdminTaskResponse::Error
374 }
375 },
376 AdminTaskRequest::RefreshReplicationConsumer => match repl_ctrl_tx.as_mut() {
377 Some(ctrl_tx) => replication_consumer_refresh(ctrl_tx).await,
378 None => {
379 error!("replication not configured, unable to refresh consumer.");
380 AdminTaskResponse::Error
381 }
382 },
383
384 AdminTaskRequest::DomainShow => match server_ro.handle_domain_show(eventid).await {
385 Ok(domain_info) => AdminTaskResponse::DomainShow { domain_info },
386 Err(e) => {
387 error!(err = ?e, "error during domain show");
388 AdminTaskResponse::Error
389 }
390 },
391 AdminTaskRequest::DomainUpgradeCheck => {
392 match server_ro.handle_domain_upgrade_check(eventid).await {
393 Ok(report) => AdminTaskResponse::DomainUpgradeCheck { report },
394 Err(e) => {
395 error!(err = ?e, "error during domain upgrade checkr");
396 AdminTaskResponse::Error
397 }
398 }
399 }
400 AdminTaskRequest::DomainRaise => match server_rw.handle_domain_raise(eventid).await
401 {
402 Ok(level) => AdminTaskResponse::DomainRaise { level },
403 Err(e) => {
404 error!(err = ?e, "error during domain raise");
405 AdminTaskResponse::Error
406 }
407 },
408 AdminTaskRequest::DomainRemigrate { level } => {
409 match server_rw.handle_domain_remigrate(level, eventid).await {
410 Ok(()) => AdminTaskResponse::Success,
411 Err(e) => {
412 error!(err = ?e, "error during domain remigrate");
413 AdminTaskResponse::Error
414 }
415 }
416 }
417 AdminTaskRequest::Reload => match broadcast_tx.send(CoreAction::Reload) {
418 Ok(_) => AdminTaskResponse::Success,
419 Err(e) => {
420 error!(err = ?e, "error during server reload");
421 AdminTaskResponse::Error
422 }
423 },
424 }
425 }
426 .instrument(nspan)
427 .await;
428
429 reqs.send(resp).await?;
430 reqs.flush().await?;
431 }
432
433 debug!("Disconnecting client ...");
434 Ok(())
435}