// kanidm_unix_common/json_codec.rs

1use crate::constants::{CODEC_BYTESMUT_ALLOCATION_LIMIT, CODEC_MIMIMUM_BYTESMUT_ALLOCATION};
2use bytes::{BufMut, BytesMut};
3use serde::{de::DeserializeOwned, Serialize};
4use std::io;
5use std::marker::PhantomData;
6use tokio_util::codec::{Decoder, Encoder};
7
/// Width in bytes of the big-endian `u32` length prefix that precedes every frame.
const U32_WIDTH: usize = std::mem::size_of::<u32>();
9
/// A `tokio_util` codec that frames JSON messages behind a 4-byte big-endian
/// length prefix.
///
/// `D` is the message type produced when decoding and `E` is the message type
/// accepted when encoding. The codec carries no runtime state; both type
/// parameters are recorded only through `PhantomData`.
pub struct JsonCodec<D, E> {
    phantom_d: PhantomData<D>,
    phantom_e: PhantomData<E>,
}

impl<D, E> Default for JsonCodec<D, E> {
    /// Build a codec. There is nothing to configure, so all instances are equivalent.
    fn default() -> Self {
        JsonCodec {
            phantom_e: PhantomData,
            phantom_d: PhantomData,
        }
    }
}
23
24impl<D, E> Decoder for JsonCodec<D, E>
25where
26    D: DeserializeOwned,
27{
28    type Error = io::Error;
29    type Item = D;
30
31    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
32        if src.len() < U32_WIDTH {
33            // Need more data, at least U32_WIDTH bytes for len
34            return Ok(None);
35        }
36
37        let mut len = [0u8; U32_WIDTH];
38        len.copy_from_slice(&src[0..U32_WIDTH]);
39
40        let len = u32::from_be_bytes(len);
41        let frame_len = U32_WIDTH + len as usize;
42
43        if src.len() < frame_len {
44            // Need more data, at least U32_WIDTH bytes for len, plus the frame size.
45            return Ok(None);
46        }
47
48        // We have the data, lets go.
49        let buffer = src.split_to(frame_len);
50        // This is the frame bytes now.
51        let frame = &buffer[U32_WIDTH..];
52
53        let response = match serde_json::from_slice::<D>(frame) {
54            Ok(msg) => Ok(Some(msg)),
55            Err(json_err) => {
56                error!(?json_err);
57                Err(io::Error::other("Invalid JSON frame"))
58            }
59        };
60
61        // Manage the buffer.
62        if src.is_empty() && src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
63            trace!("buffer trim");
64            let mut empty = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
65            std::mem::swap(&mut empty, src);
66        }
67
68        response
69    }
70}
71
72impl<D, E> Encoder<E> for JsonCodec<D, E>
73where
74    E: Serialize,
75{
76    type Error = io::Error;
77
78    fn encode(&mut self, msg: E, dst: &mut BytesMut) -> Result<(), Self::Error> {
79        let data = serde_json::to_vec(&msg).map_err(|e| {
80            error!("socket encoding error -> {:?}", e);
81            io::Error::other("JSON encode error")
82        })?;
83
84        // Encode how many bytes we wrote
85        let len = data.len() as u32;
86
87        if len == 0 {
88            warn!("refusing to write empty frame.");
89            return Ok(());
90        }
91
92        dst.put(len.to_be_bytes().as_slice());
93        dst.put(data.as_slice());
94        Ok(())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::{JsonCodec, U32_WIDTH};
    use bytes::{BufMut, BytesMut};
    use serde::{Deserialize, Serialize};
    use tokio_util::codec::{Decoder, Encoder};

    #[derive(Serialize, Deserialize, Debug)]
    enum Msg {
        Test,
    }

    #[test]
    fn test_json_codec() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut buffer = BytesMut::new();

        // There should be nothing by default
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(None)));

        // Write a frame
        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");

        // Buffer should have bytes: the length prefix plus `"Test"` (6 bytes).
        assert_eq!(buffer.len(), U32_WIDTH + 6);

        // Decode
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        // The frame was fully consumed.
        assert_eq!(buffer.len(), 0);

        // Queue up multiple messages.
        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");
        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");
        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");

        assert_eq!(buffer.len(), (U32_WIDTH + 6) * 3);

        // Decode the first
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        // Do we have more data?
        assert_eq!(buffer.len(), (U32_WIDTH + 6) * 2);

        // Pull out the rest
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        // Done!
        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_json_codec_partial_frame() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut full = BytesMut::new();
        codec
            .encode(Msg::Test, &mut full)
            .expect("Failed to encode");

        // Feed the frame one byte at a time. Nothing may decode, and nothing may
        // be consumed, until the final byte arrives.
        let (last, head) = full.split_last().expect("encoded frame is non-empty");
        let mut buffer = BytesMut::new();
        for byte in head {
            buffer.put_u8(*byte);
            assert!(matches!(codec.decode(&mut buffer), Ok(None)));
        }
        assert_eq!(buffer.len(), full.len() - 1);

        buffer.put_u8(*last);
        assert!(matches!(codec.decode(&mut buffer), Ok(Some(Msg::Test))));
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_json_codec_invalid_json() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut buffer = BytesMut::new();

        let body = b"not json";
        buffer.put_u32(u32::try_from(body.len()).expect("short literal fits in u32"));
        buffer.put_slice(body);

        assert!(codec.decode(&mut buffer).is_err());
        // The malformed frame is still consumed so it cannot be re-read.
        assert!(buffer.is_empty());
    }
}