kanidmd_core/repl/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::io;
4use tokio_util::codec::{Decoder, Encoder};
5
6use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
7
/// The minimum size of a buffer for the replication codec (1MB).
///
/// Used as the capacity of the fresh buffer that replaces an oversized one.
/// (The misspelling in the name is kept for API compatibility.)
pub const CODEC_MIMIMUM_BYTESMUT_ALLOCATION: usize = 1 << 20;

/// Capacity threshold (8MB) at which an idle codec buffer is swapped for a
/// fresh `CODEC_MIMIMUM_BYTESMUT_ALLOCATION`-sized one to prevent memory
/// explosions. The checks use `>=`, so reaching the limit is enough.
pub const CODEC_BYTESMUT_ALLOCATION_LIMIT: usize = 8 << 20;
13
14#[derive(Serialize, Deserialize, Debug)]
15pub enum ConsumerRequest {
16    Ping,
17    Incremental(ReplRuvRange),
18    Refresh,
19}
20
21#[derive(Serialize, Deserialize, Debug)]
22pub enum SupplierResponse {
23    Pong,
24    Incremental(ReplIncrementalContext),
25    Refresh(ReplRefreshContext),
26}
27
28#[derive(Default)]
29pub struct ConsumerCodec {
30    max_frame_bytes: usize,
31}
32
33impl ConsumerCodec {
34    pub fn new(max_frame_bytes: usize) -> Self {
35        ConsumerCodec { max_frame_bytes }
36    }
37}
38
39impl Decoder for ConsumerCodec {
40    type Error = io::Error;
41    type Item = SupplierResponse;
42
43    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
44        decode_length_checked_json(self.max_frame_bytes, src)
45    }
46}
47
48impl Encoder<ConsumerRequest> for ConsumerCodec {
49    type Error = io::Error;
50
51    fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
52        encode_length_checked_json(msg, dst)
53    }
54}
55
56#[derive(Default)]
57pub struct SupplierCodec {
58    max_frame_bytes: usize,
59}
60
61impl SupplierCodec {
62    pub fn new(max_frame_bytes: usize) -> Self {
63        SupplierCodec { max_frame_bytes }
64    }
65}
66
67impl Decoder for SupplierCodec {
68    type Error = io::Error;
69    type Item = ConsumerRequest;
70
71    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
72        decode_length_checked_json(self.max_frame_bytes, src)
73    }
74}
75
76impl Encoder<SupplierResponse> for SupplierCodec {
77    type Error = io::Error;
78
79    fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
80        encode_length_checked_json(msg, dst)
81    }
82}
83
84fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
85    // If the outgoing buffer is empty AND greater than our allocation limit, we
86    // want to attempt to free space.
87    if dst.is_empty() && dst.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
88        dst.clear();
89        let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
90        std::mem::swap(&mut buf, dst);
91    }
92
93    // First, if there is anything already in dst, we should split past it.
94    let mut work = dst.split_off(dst.len());
95
96    // Null the head of the buffer.
97    let zero_len = u64::MIN.to_be_bytes();
98    work.extend_from_slice(&zero_len);
99
100    // skip the buffer ahead 8 bytes.
101    // Remember, this split returns the *already set* bytes.
102    // ⚠️  Can't use split or split_at - these return the
103    // len bytes into a new bytes mut which confuses unsplit
104    // by appending the value when we need to append our json.
105    let json_buf = work.split_off(zero_len.len());
106
107    let mut json_writer = json_buf.writer();
108
109    serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
110        error!(?err, "consumer encoding error");
111        io::Error::other("JSON encode error")
112    })?;
113
114    let json_buf = json_writer.into_inner();
115
116    let final_len = json_buf.len() as u64;
117    let final_len_bytes = final_len.to_be_bytes();
118
119    if final_len_bytes.len() != work.len() {
120        error!("consumer buffer size error");
121        return Err(io::Error::other("buffer length error"));
122    }
123
124    work.copy_from_slice(&final_len_bytes);
125
126    // Now stitch them back together.
127    work.unsplit(json_buf);
128
129    dst.unsplit(work);
130
131    Ok(())
132}
133
134fn decode_length_checked_json<T: DeserializeOwned>(
135    max_frame_bytes: usize,
136    src: &mut BytesMut,
137) -> Result<Option<T>, io::Error> {
138    trace!(capacity = ?src.capacity());
139
140    if src.len() < 8 {
141        // Not enough for the length header.
142        trace!("Insufficient bytes for length header.");
143        return Ok(None);
144    }
145
146    let (src_len_bytes, json_bytes) = src.split_at(8);
147    let mut len_be_bytes = [0; 8];
148
149    assert_eq!(len_be_bytes.len(), src_len_bytes.len());
150    len_be_bytes.copy_from_slice(src_len_bytes);
151    let req_len = u64::from_be_bytes(len_be_bytes);
152
153    if req_len == 0 {
154        error!("request has size 0");
155        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
156    }
157
158    if req_len > max_frame_bytes as u64 {
159        error!(
160            "requested decode frame too large {} > {}",
161            req_len, max_frame_bytes
162        );
163        return Err(io::Error::new(
164            io::ErrorKind::OutOfMemory,
165            "request too large",
166        ));
167    }
168
169    if (json_bytes.len() as u64) < req_len {
170        trace!(
171            "Insufficient bytes for json, need: {} have: {}",
172            req_len,
173            src.len()
174        );
175        return Ok(None);
176    }
177
178    // If there are excess bytes, we need to limit our slice to that view.
179    debug_assert!(req_len as usize <= json_bytes.len());
180    let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
181
182    // Okay, we have enough. Lets go.
183    let res = serde_json::from_slice(json_bytes)
184        .map(|msg| Some(msg))
185        .map_err(|err| {
186            error!(?err, "received invalid input");
187            io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
188        });
189
190    // Trim to length.
191    if src.len() as u64 == req_len {
192        src.clear();
193        if src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
194            let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
195            std::mem::swap(&mut buf, src);
196        }
197    } else {
198        src.advance((8 + req_len) as usize);
199    };
200
201    res
202}
203
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    /// Covers framing edge cases (short header, zero length, oversized frame,
    /// partial body) and round-trips Ping/Pong between the two codecs.
    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer = ConsumerCodec::new(32);
        let mut wire = BytesMut::with_capacity(32);

        // Nothing buffered yet: the decoder must ask for more input.
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // Half a length header is still not enough to make progress.
        let half_header = [0u8; 4];
        wire.extend_from_slice(&half_header);
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // Completing the header with zeros announces an empty body: rejected.
        wire.extend_from_slice(&half_header);
        assert_eq!(wire.len(), 8);
        assert!(consumer.decode(&mut wire).is_err());

        // A header announcing 34 bytes exceeds the 32 byte limit. This must
        // fail straight away even though no body bytes have arrived.
        wire.clear();
        wire.extend_from_slice(&34_u64.to_be_bytes());
        assert_eq!(wire.len(), 8);
        assert!(consumer.decode(&mut wire).is_err());

        // A permitted length with only part of the body buffered asks for more.
        wire.clear();
        wire.extend_from_slice(&20_u64.to_be_bytes());
        wire.extend_from_slice(&half_header);
        assert_eq!(wire.len(), 12);
        assert!(matches!(consumer.decode(&mut wire), Ok(None)));

        // Round trip a single request and its response.
        wire.clear();
        let mut supplier = SupplierCodec::new(32);

        assert!(consumer.encode(ConsumerRequest::Ping, &mut wire).is_ok());
        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));
        // Decoding the only frame consumes the whole buffer.
        assert!(wire.is_empty());

        assert!(supplier.encode(SupplierResponse::Pong, &mut wire).is_ok());
        assert!(matches!(
            consumer.decode(&mut wire),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(wire.is_empty());

        // Two frames queued back to back are decoded one at a time.
        wire.clear();
        let mut supplier = SupplierCodec::new(32);

        for _ in 0..2 {
            assert!(consumer.encode(ConsumerRequest::Ping, &mut wire).is_ok());
        }

        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(!wire.is_empty());
        assert!(matches!(
            supplier.decode(&mut wire),
            Ok(Some(ConsumerRequest::Ping))
        ));

        // Both frames consumed: nothing is left over.
        assert!(wire.is_empty());
    }
}