kanidmd_core/repl/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4use std::io;
5use tokio_util::codec::{Decoder, Encoder};
6
/// Capacity, in bytes, given to a freshly allocated codec buffer (1 MiB).
///
/// The identifier's spelling is kept as-is because it is part of the public
/// interface of this module.
pub const CODEC_MIMIMUM_BYTESMUT_ALLOCATION: usize = 1 << 20;
/// Capacity threshold, in bytes (8 MiB). Once a codec buffer has reached this
/// capacity it is exchanged for a fresh, smaller one so that a single large
/// frame does not keep memory pinned indefinitely.
pub const CODEC_BYTESMUT_ALLOCATION_LIMIT: usize = 8 << 20;
12
13#[derive(Serialize, Deserialize, Debug)]
14pub enum ConsumerRequest {
15    Ping,
16    Incremental(ReplRuvRange),
17    Refresh,
18}
19
20#[derive(Serialize, Deserialize, Debug)]
21pub enum SupplierResponse {
22    Pong,
23    Incremental(ReplIncrementalContext),
24    Refresh(ReplRefreshContext),
25}
26
/// Framing codec that encodes [`ConsumerRequest`] values and decodes
/// [`SupplierResponse`] values.
///
/// Note that the derived `Default` yields a frame limit of `0`, and the
/// decoder rejects any frame whose declared length is zero or exceeds the
/// limit, so a defaulted codec cannot decode anything.
#[derive(Default)]
pub struct ConsumerCodec {
    /// Largest JSON payload, in bytes, that the decoder will accept.
    max_frame_bytes: usize,
}

impl ConsumerCodec {
    /// Build a codec whose decoder accepts payloads of at most
    /// `max_frame_bytes` bytes.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
37
38impl Decoder for ConsumerCodec {
39    type Error = io::Error;
40    type Item = SupplierResponse;
41
42    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
43        decode_length_checked_json(self.max_frame_bytes, src)
44    }
45}
46
47impl Encoder<ConsumerRequest> for ConsumerCodec {
48    type Error = io::Error;
49
50    fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
51        encode_length_checked_json(msg, dst)
52    }
53}
54
/// Framing codec that encodes [`SupplierResponse`] values and decodes
/// [`ConsumerRequest`] values.
///
/// Note that the derived `Default` yields a frame limit of `0`, and the
/// decoder rejects any frame whose declared length is zero or exceeds the
/// limit, so a defaulted codec cannot decode anything.
#[derive(Default)]
pub struct SupplierCodec {
    /// Largest JSON payload, in bytes, that the decoder will accept.
    max_frame_bytes: usize,
}

impl SupplierCodec {
    /// Build a codec whose decoder accepts payloads of at most
    /// `max_frame_bytes` bytes.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
65
66impl Decoder for SupplierCodec {
67    type Error = io::Error;
68    type Item = ConsumerRequest;
69
70    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
71        decode_length_checked_json(self.max_frame_bytes, src)
72    }
73}
74
75impl Encoder<SupplierResponse> for SupplierCodec {
76    type Error = io::Error;
77
78    fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
79        encode_length_checked_json(msg, dst)
80    }
81}
82
83fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
84    // If the outgoing buffer is empty AND greater than our allocation limit, we
85    // want to attempt to free space.
86    if dst.is_empty() && dst.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
87        dst.clear();
88        let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
89        std::mem::swap(&mut buf, dst);
90    }
91
92    // First, if there is anything already in dst, we should split past it.
93    let mut work = dst.split_off(dst.len());
94
95    // Null the head of the buffer.
96    let zero_len = u64::MIN.to_be_bytes();
97    work.extend_from_slice(&zero_len);
98
99    // skip the buffer ahead 8 bytes.
100    // Remember, this split returns the *already set* bytes.
101    // ⚠️  Can't use split or split_at - these return the
102    // len bytes into a new bytes mut which confuses unsplit
103    // by appending the value when we need to append our json.
104    let json_buf = work.split_off(zero_len.len());
105
106    let mut json_writer = json_buf.writer();
107
108    serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
109        error!(?err, "consumer encoding error");
110        io::Error::other("JSON encode error")
111    })?;
112
113    let json_buf = json_writer.into_inner();
114
115    let final_len = json_buf.len() as u64;
116    let final_len_bytes = final_len.to_be_bytes();
117
118    if final_len_bytes.len() != work.len() {
119        error!("consumer buffer size error");
120        return Err(io::Error::other("buffer length error"));
121    }
122
123    work.copy_from_slice(&final_len_bytes);
124
125    // Now stitch them back together.
126    work.unsplit(json_buf);
127
128    dst.unsplit(work);
129
130    Ok(())
131}
132
133fn decode_length_checked_json<T: DeserializeOwned>(
134    max_frame_bytes: usize,
135    src: &mut BytesMut,
136) -> Result<Option<T>, io::Error> {
137    trace!(capacity = ?src.capacity());
138
139    if src.len() < 8 {
140        // Not enough for the length header.
141        trace!("Insufficient bytes for length header.");
142        return Ok(None);
143    }
144
145    let (src_len_bytes, json_bytes) = src.split_at(8);
146    let mut len_be_bytes = [0; 8];
147
148    assert_eq!(len_be_bytes.len(), src_len_bytes.len());
149    len_be_bytes.copy_from_slice(src_len_bytes);
150    let req_len = u64::from_be_bytes(len_be_bytes);
151
152    if req_len == 0 {
153        error!("request has size 0");
154        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
155    }
156
157    if req_len > max_frame_bytes as u64 {
158        error!(
159            "requested decode frame too large {} > {}",
160            req_len, max_frame_bytes
161        );
162        return Err(io::Error::new(
163            io::ErrorKind::OutOfMemory,
164            "request too large",
165        ));
166    }
167
168    if (json_bytes.len() as u64) < req_len {
169        trace!(
170            "Insufficient bytes for json, need: {} have: {}",
171            req_len,
172            src.len()
173        );
174        return Ok(None);
175    }
176
177    // If there are excess bytes, we need to limit our slice to that view.
178    debug_assert!(req_len as usize <= json_bytes.len());
179    let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
180
181    // Okay, we have enough. Lets go.
182    let res = serde_json::from_slice(json_bytes)
183        .map(|msg| Some(msg))
184        .map_err(|err| {
185            error!(?err, "received invalid input");
186            io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
187        });
188
189    // Trim to length.
190    if src.len() as u64 == req_len {
191        src.clear();
192        if src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
193            let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
194            std::mem::swap(&mut buf, src);
195        }
196    } else {
197        src.advance((8 + req_len) as usize);
198    };
199
200    res
201}
202
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    /// Walks a single shared buffer through the decoder's edge cases (short
    /// header, zero length, oversized length, partial payload) and then
    /// through full consumer <-> supplier round trips.
    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer = ConsumerCodec::new(32);
        let mut wire = BytesMut::with_capacity(32);

        // Nothing buffered at all: the decoder asks for more data.
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // Half of a length header: still asks for more data.
        let half_header = [0, 0, 0, 0];
        wire.extend_from_slice(&half_header);
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // A complete header that declares a zero-length payload is an error.
        wire.extend_from_slice(&half_header);
        assert_eq!(wire.len(), 8);
        assert!(consumer.decode(&mut wire).is_err());

        // A header declaring more than the permitted maximum is rejected
        // straight away, even though only the 8 header bytes have arrived.
        wire.clear();
        let oversized = (34_u64).to_be_bytes();
        wire.extend_from_slice(&oversized);
        assert_eq!(wire.len(), 8);
        assert!(consumer.decode(&mut wire).is_err());

        // A permitted length with only part of the payload present: the
        // decoder asks for more data.
        wire.clear();
        let acceptable = (20_u64).to_be_bytes();
        wire.extend_from_slice(&acceptable);
        wire.extend_from_slice(&half_header);
        assert_eq!(wire.len(), 12);
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // Full round trip: request one way, response the other.
        wire.clear();
        let mut supplier = SupplierCodec::new(32);

        assert!(consumer.encode(ConsumerRequest::Ping, &mut wire).is_ok());
        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));
        // Decoding the only frame leaves the buffer empty.
        assert!(wire.is_empty());

        assert!(supplier.encode(SupplierResponse::Pong, &mut wire).is_ok());
        assert!(matches!(
            consumer.decode(&mut wire),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(wire.is_empty());

        // Two frames queued back to back are decoded one at a time.
        wire.clear();
        let mut supplier = SupplierCodec::new(32);

        assert!(consumer.encode(ConsumerRequest::Ping, &mut wire).is_ok());
        assert!(consumer.encode(ConsumerRequest::Ping, &mut wire).is_ok());

        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));
        // The second frame is still waiting in the buffer.
        assert!(!wire.is_empty());
        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));

        // Both frames consumed: nothing is left over.
        assert!(wire.is_empty());
    }
}