kanidmd_core/repl/codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::io;
4use tokio_util::codec::{Decoder, Encoder};
5
6use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
7
8#[derive(Serialize, Deserialize, Debug)]
9pub enum ConsumerRequest {
10    Ping,
11    Incremental(ReplRuvRange),
12    Refresh,
13}
14
15#[derive(Serialize, Deserialize, Debug)]
16pub enum SupplierResponse {
17    Pong,
18    Incremental(ReplIncrementalContext),
19    Refresh(ReplRefreshContext),
20}
21
/// Framing codec for the consumer end of a replication connection.
///
/// It encodes [`ConsumerRequest`]s and decodes [`SupplierResponse`]s. Each
/// frame is an 8-byte big-endian length followed by a JSON body. Incoming
/// frames that declare a body longer than `max_frame_bytes` are rejected.
///
/// NOTE: `ConsumerCodec::default()` leaves `max_frame_bytes` at 0. Because
/// zero-length frames are also rejected, such a codec cannot decode anything;
/// construct it with [`ConsumerCodec::new`] and a real limit.
#[derive(Default)]
pub struct ConsumerCodec {
    // Upper bound, in bytes, on the body length the decoder will accept.
    max_frame_bytes: usize,
}

impl ConsumerCodec {
    /// Build a codec whose decoder rejects frames longer than `max_frame_bytes`.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
32
33impl Decoder for ConsumerCodec {
34    type Error = io::Error;
35    type Item = SupplierResponse;
36
37    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
38        decode_length_checked_json(self.max_frame_bytes, src)
39    }
40}
41
42impl Encoder<ConsumerRequest> for ConsumerCodec {
43    type Error = io::Error;
44
45    fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
46        encode_length_checked_json(msg, dst)
47    }
48}
49
/// Framing codec for the supplier end of a replication connection.
///
/// It decodes [`ConsumerRequest`]s and encodes [`SupplierResponse`]s. Each
/// frame is an 8-byte big-endian length followed by a JSON body. Incoming
/// frames that declare a body longer than `max_frame_bytes` are rejected.
///
/// NOTE: `SupplierCodec::default()` leaves `max_frame_bytes` at 0. Because
/// zero-length frames are also rejected, such a codec cannot decode anything;
/// construct it with [`SupplierCodec::new`] and a real limit.
#[derive(Default)]
pub struct SupplierCodec {
    // Upper bound, in bytes, on the body length the decoder will accept.
    max_frame_bytes: usize,
}

impl SupplierCodec {
    /// Build a codec whose decoder rejects frames longer than `max_frame_bytes`.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
60
61impl Decoder for SupplierCodec {
62    type Error = io::Error;
63    type Item = ConsumerRequest;
64
65    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66        decode_length_checked_json(self.max_frame_bytes, src)
67    }
68}
69
70impl Encoder<SupplierResponse> for SupplierCodec {
71    type Error = io::Error;
72
73    fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
74        encode_length_checked_json(msg, dst)
75    }
76}
77
78fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
79    // First, if there is anything already in dst, we should split past it.
80    let mut work = dst.split_off(dst.len());
81
82    // Null the head of the buffer.
83    let zero_len = u64::MIN.to_be_bytes();
84    work.extend_from_slice(&zero_len);
85
86    // skip the buffer ahead 8 bytes.
87    // Remember, this split returns the *already set* bytes.
88    // ⚠️  Can't use split or split_at - these return the
89    // len bytes into a new bytes mut which confuses unsplit
90    // by appending the value when we need to append our json.
91    let json_buf = work.split_off(zero_len.len());
92
93    let mut json_writer = json_buf.writer();
94
95    serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
96        error!(?err, "consumer encoding error");
97        io::Error::new(io::ErrorKind::Other, "JSON encode error")
98    })?;
99
100    let json_buf = json_writer.into_inner();
101
102    let final_len = json_buf.len() as u64;
103    let final_len_bytes = final_len.to_be_bytes();
104
105    if final_len_bytes.len() != work.len() {
106        error!("consumer buffer size error");
107        return Err(io::Error::new(io::ErrorKind::Other, "buffer length error"));
108    }
109
110    work.copy_from_slice(&final_len_bytes);
111
112    // Now stitch them back together.
113    work.unsplit(json_buf);
114
115    dst.unsplit(work);
116
117    Ok(())
118}
119
120fn decode_length_checked_json<T: DeserializeOwned>(
121    max_frame_bytes: usize,
122    src: &mut BytesMut,
123) -> Result<Option<T>, io::Error> {
124    trace!(capacity = ?src.capacity());
125
126    if src.len() < 8 {
127        // Not enough for the length header.
128        trace!("Insufficient bytes for length header.");
129        return Ok(None);
130    }
131
132    let (src_len_bytes, json_bytes) = src.split_at(8);
133    let mut len_be_bytes = [0; 8];
134
135    assert_eq!(len_be_bytes.len(), src_len_bytes.len());
136    len_be_bytes.copy_from_slice(src_len_bytes);
137    let req_len = u64::from_be_bytes(len_be_bytes);
138
139    if req_len == 0 {
140        error!("request has size 0");
141        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
142    }
143
144    if req_len > max_frame_bytes as u64 {
145        error!(
146            "requested decode frame too large {} > {}",
147            req_len, max_frame_bytes
148        );
149        return Err(io::Error::new(
150            io::ErrorKind::OutOfMemory,
151            "request too large",
152        ));
153    }
154
155    if (json_bytes.len() as u64) < req_len {
156        trace!(
157            "Insufficient bytes for json, need: {} have: {}",
158            req_len,
159            src.len()
160        );
161        return Ok(None);
162    }
163
164    // If there are excess bytes, we need to limit our slice to that view.
165    debug_assert!(req_len as usize <= json_bytes.len());
166    let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
167
168    // Okay, we have enough. Lets go.
169    let res = serde_json::from_slice(json_bytes)
170        .map(|msg| Some(msg))
171        .map_err(|err| {
172            error!(?err, "received invalid input");
173            io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
174        });
175
176    // Trim to length.
177    if src.len() as u64 == req_len {
178        src.clear();
179    } else {
180        src.advance((8 + req_len) as usize);
181    };
182
183    res
184}
185
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);

        let mut buf = BytesMut::with_capacity(32);

        // Empty buffer
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);

        // Not enough to fill the length header.
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Length header reports a zero size request.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Clear buffer - setup a request with a length > allowed max.
        buf.clear();
        let len_bytes = (34_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);

        // Even though the buf len is only 8, this will error as the overall
        // request will be too large.
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Assert that we request more data on a validly sized req
        buf.clear();
        let len_bytes = (20_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);
        // Pad in some extra bytes.
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 12);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Make a request that is correctly sized.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        // The buf will have been cleared by the supplier codec here.
        assert!(buf.is_empty());
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());

        // Make two requests in a row.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());

        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(!buf.is_empty());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));

        // The buf will have been cleared by the supplier codec here.
        assert!(buf.is_empty());
    }

    /// A frame delivered one byte at a time is only decoded once complete,
    /// and every incomplete prefix leaves the buffer untouched.
    #[test]
    fn test_repl_codec_partial_frame() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        // Encode one complete frame into a scratch buffer.
        let mut frame = BytesMut::new();
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut frame)
            .is_ok());
        assert!(frame.len() > 8);

        // Feed all but the last byte; each prefix is incomplete.
        let mut buf = BytesMut::new();
        let (last, head) = frame.split_last().expect("encoded frame is non-empty");
        for byte in head {
            buf.extend_from_slice(std::slice::from_ref(byte));
            assert!(matches!(supplier_codec.decode(&mut buf), Ok(None)));
            assert!(!buf.is_empty());
        }

        // The final byte completes the frame.
        buf.extend_from_slice(std::slice::from_ref(last));
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());
    }

    /// A frame whose body is not valid JSON is rejected, but still consumed,
    /// so a following valid frame decodes normally.
    #[test]
    fn test_repl_codec_malformed_json() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let mut buf = BytesMut::new();
        // A well-formed header wrapping a 2-byte body that is not JSON.
        buf.extend_from_slice(&2_u64.to_be_bytes());
        buf.extend_from_slice(b"xx");
        // Followed by a valid frame.
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());

        assert!(consumer_codec.decode(&mut buf).is_err());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());
    }
}