kanidm_unix_common/
client_sync.rs

1use crate::constants::DEFAULT_CONN_TIMEOUT;
2use crate::json_codec::JsonCodec;
3use crate::unix_proto::{ClientRequest, ClientResponse};
4use bytes::BytesMut;
5use std::error::Error;
6use std::io::{self, ErrorKind, Read, Write};
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use tokio_util::codec::{Decoder, Encoder};
10
11pub use std::os::unix::net::UnixStream;
12
13type ClientCodec = JsonCodec<ClientResponse, ClientRequest>;
14
15#[derive(Clone)]
16pub struct DaemonClientBlocking {
17    inner: Arc<Mutex<DaemonClientBlockingInner>>,
18}
19
20struct DaemonClientBlockingInner {
21    stream: UnixStream,
22    codec: ClientCodec,
23    default_timeout: u64,
24    reconnect: bool,
25}
26
27impl From<UnixStream> for DaemonClientBlocking {
28    fn from(stream: UnixStream) -> Self {
29        DaemonClientBlocking {
30            inner: Arc::new(Mutex::new(DaemonClientBlockingInner {
31                stream,
32                codec: ClientCodec::default(),
33                default_timeout: DEFAULT_CONN_TIMEOUT,
34                reconnect: false,
35            })),
36        }
37    }
38}
39
40impl DaemonClientBlocking {
41    pub fn new(path: &str, default_timeout: u64) -> Result<DaemonClientBlocking, Box<dyn Error>> {
42        // Setup a subscriber incase one isn't setup.
43        if cfg!(feature = "client_sync_tracing") {
44            use tracing_subscriber::prelude::*;
45            use tracing_subscriber::{filter::LevelFilter, fmt};
46
47            let fmt_layer = fmt::layer().with_target(false);
48            let filter_layer = LevelFilter::WARN;
49
50            let _ = tracing_subscriber::registry()
51                .with(filter_layer)
52                .with(fmt_layer)
53                .try_init();
54        }
55
56        trace!(%path);
57
58        let stream = UnixStream::connect(path).map_err(|err| {
59            error!(
60                ?err, %path,
61                "Unix socket stream setup error",
62            );
63            Box::new(err)
64        })?;
65
66        Ok(DaemonClientBlocking {
67            inner: Arc::new(Mutex::new(DaemonClientBlockingInner {
68                stream,
69                codec: ClientCodec::default(),
70                default_timeout,
71                reconnect: false,
72            })),
73        })
74    }
75
76    pub fn call_and_wait(
77        &self,
78        req: ClientRequest,
79        timeout: Option<u64>,
80    ) -> Result<ClientResponse, Box<dyn Error + '_>> {
81        let mut guard = self.inner.lock().map_err(|err| {
82            error!(?err, "critical, daemon client mutex has been poisoned!!!");
83            Box::new(err)
84        })?;
85
86        if guard.reconnect {
87            let peer_addr = guard.stream.peer_addr().map_err(|err| {
88                error!(
89                    ?err,
90                    "critical, stream has no peer address, unable to reconnect!!!"
91                );
92                Box::new(err)
93            })?;
94
95            let mut new_stream = UnixStream::connect_addr(&peer_addr).map_err(|err| {
96                error!(?err, ?peer_addr, "Unix socket stream setup error",);
97                Box::new(err)
98            })?;
99
100            debug!("Reconnection complete.");
101
102            std::mem::swap(&mut guard.stream, &mut new_stream);
103            guard.reconnect = false;
104        }
105
106        guard.call_and_wait(req, timeout).inspect_err(|_| {
107            debug!("error occured during communication, will reconnect ...");
108            guard.reconnect = true;
109        })
110    }
111}
112
113impl DaemonClientBlockingInner {
114    fn call_and_wait(
115        &mut self,
116        req: ClientRequest,
117        timeout: Option<u64>,
118    ) -> Result<ClientResponse, Box<dyn Error>> {
119        let timeout = Duration::from_secs(timeout.unwrap_or(self.default_timeout));
120
121        self.stream
122            .set_write_timeout(Some(timeout))
123            .map_err(|err| {
124                error!(
125                    ?err,
126                    "Unix socket stream setup error while setting write timeout",
127                );
128                Box::new(err)
129            })?;
130
131        // We want this to be blocking so that we wait for data to be ready
132        self.stream.set_nonblocking(false).map_err(|err| {
133            error!(
134                ?err,
135                "Unix socket stream setup error while setting nonblocking=false",
136            );
137            Box::new(err)
138        })?;
139
140        let mut data = BytesMut::new();
141
142        self.codec.encode(req, &mut data).map_err(|err| {
143            error!(?err, "codec encode error");
144            Box::new(err)
145        })?;
146
147        self.stream
148            .write_all(&data)
149            .and_then(|_| self.stream.flush())
150            .map_err(|err| {
151                error!(?err, "stream write error");
152                Box::new(err)
153            })?;
154
155        // Set our read timeout
156        self.stream.set_read_timeout(Some(timeout)).map_err(|err| {
157            error!(
158                ?err,
159                "Unix socket stream setup error while setting read timeout",
160            );
161            Box::new(err)
162        })?;
163
164        // We want this to be blocking so that we wait for data to be ready
165        self.stream.set_nonblocking(false).map_err(|err| {
166            error!(
167                ?err,
168                "Unix socket stream setup error while setting nonblocking=false",
169            );
170            Box::new(err)
171        })?;
172
173        trace!(read_timeout = ?self.stream.read_timeout(), write_timeout = ?self.stream.write_timeout());
174
175        // Now wait on the response.
176        data.clear();
177        let start = Instant::now();
178        let mut read_started = false;
179
180        loop {
181            trace!("read loop");
182            let durr = Instant::now().duration_since(start);
183            if durr > timeout {
184                error!("Socket timeout");
185                // timed out, not enough activity.
186                return Err(Box::new(io::Error::other("Timeout")));
187            }
188
189            let mut buffer = [0; 16 * 1024];
190
191            // Would be a lot easier if we had peek ...
192            // https://github.com/rust-lang/rust/issues/76923
193            match self.stream.read(&mut buffer) {
194                Ok(0) => {
195                    if read_started {
196                        trace!("read_started true, no bytes read");
197                        // We're done, no more bytes. This will now
198                        // fall through to the codec decode to double
199                        // check this assertion.
200                    } else {
201                        trace!("Waiting ...");
202                        // Still can wait ...
203                        continue;
204                    }
205                }
206                Ok(count) => {
207                    read_started = true;
208                    trace!("read {count} bytes");
209                    data.extend_from_slice(&buffer[..count]);
210                    if count == buffer.len() {
211                        // Whole buffer, read again
212                        continue;
213                    }
214                    // Not a whole buffer, probably complete.
215                }
216                Err(err) if err.kind() == ErrorKind::WouldBlock => {
217                    warn!("read from UDS would block, try again.");
218                    // std::thread::sleep(Duration::from_millis(1));
219                    std::thread::yield_now();
220                    continue;
221                }
222                Err(err) => {
223                    error!(?err, err_kind = ?err.kind(), "Stream read failure from {:?}", &self.stream);
224                    // Failure!
225                    return Err(Box::new(err));
226                }
227            }
228
229            match self.codec.decode(&mut data) {
230                // A whole frame is ready and present.
231                Ok(Some(cr)) => {
232                    trace!("read loop - ok");
233                    return Ok(cr);
234                }
235                // Need more data
236                Ok(None) => {
237                    trace!("need more");
238                    continue;
239                }
240                // Failed to decode for some reason
241                Err(err) => {
242                    error!(?err, "failed to decode response");
243                    return Err(Box::new(err));
244                }
245            }
246        }
247    }
248}