kanidm_unix_common/
json_codec.rs1use crate::constants::{CODEC_BYTESMUT_ALLOCATION_LIMIT, CODEC_MIMIMUM_BYTESMUT_ALLOCATION};
2use bytes::{BufMut, BytesMut};
3use serde::{de::DeserializeOwned, Serialize};
4use std::io;
5use std::marker::PhantomData;
6use tokio_util::codec::{Decoder, Encoder};
7
/// Width in bytes of the big-endian `u32` length prefix that precedes every frame.
const U32_WIDTH: usize = std::mem::size_of::<u32>();
9
/// Codec that frames values as length-prefixed JSON.
///
/// `D` is the type produced when decoding and `E` the type accepted when encoding. The codec
/// holds no values of either type, only zero-sized type markers.
pub struct JsonCodec<D, E> {
    decoded: PhantomData<D>,
    encoded: PhantomData<E>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add
// `D: Default` and `E: Default` bounds that the phantom markers do not need.
impl<D, E> Default for JsonCodec<D, E> {
    fn default() -> Self {
        JsonCodec {
            decoded: PhantomData,
            encoded: PhantomData,
        }
    }
}
23
24impl<D, E> Decoder for JsonCodec<D, E>
25where
26 D: DeserializeOwned,
27{
28 type Error = io::Error;
29 type Item = D;
30
31 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
32 if src.len() < U32_WIDTH {
33 return Ok(None);
35 }
36
37 let mut len = [0u8; U32_WIDTH];
38 len.copy_from_slice(&src[0..U32_WIDTH]);
39
40 let len = u32::from_be_bytes(len);
41 let frame_len = U32_WIDTH + len as usize;
42
43 if src.len() < frame_len {
44 return Ok(None);
46 }
47
48 let buffer = src.split_to(frame_len);
50 let frame = &buffer[U32_WIDTH..];
52
53 let response = match serde_json::from_slice::<D>(frame) {
54 Ok(msg) => Ok(Some(msg)),
55 Err(json_err) => {
56 error!(?json_err);
57 Err(io::Error::other("Invalid JSON frame"))
58 }
59 };
60
61 if src.is_empty() && src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
63 trace!("buffer trim");
64 let mut empty = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
65 std::mem::swap(&mut empty, src);
66 }
67
68 response
69 }
70}
71
72impl<D, E> Encoder<E> for JsonCodec<D, E>
73where
74 E: Serialize,
75{
76 type Error = io::Error;
77
78 fn encode(&mut self, msg: E, dst: &mut BytesMut) -> Result<(), Self::Error> {
79 let data = serde_json::to_vec(&msg).map_err(|e| {
80 error!("socket encoding error -> {:?}", e);
81 io::Error::other("JSON encode error")
82 })?;
83
84 let len = data.len() as u32;
86
87 if len == 0 {
88 warn!("refusing to write empty frame.");
89 return Ok(());
90 }
91
92 dst.put(len.to_be_bytes().as_slice());
93 dst.put(data.as_slice());
94 Ok(())
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::{JsonCodec, U32_WIDTH};
    use bytes::{BufMut, BytesMut};
    use serde::{Deserialize, Serialize};
    use tokio_util::codec::{Decoder, Encoder};

    #[derive(Serialize, Deserialize, Debug)]
    enum Msg {
        Test,
    }

    /// Encoded size of `Msg::Test`: the length prefix plus the JSON string `"Test"` (6 bytes).
    const TEST_FRAME_LEN: usize = U32_WIDTH + 6;

    #[test]
    fn test_json_codec() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut buffer = BytesMut::new();

        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(None)));

        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");

        assert_eq!(buffer.len(), TEST_FRAME_LEN);

        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        assert_eq!(buffer.len(), 0);

        for _ in 0..3 {
            codec
                .encode(Msg::Test, &mut buffer)
                .expect("Failed to encode");
        }

        assert_eq!(buffer.len(), TEST_FRAME_LEN * 3);

        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        assert_eq!(buffer.len(), TEST_FRAME_LEN * 2);

        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));
        let out = codec.decode(&mut buffer);
        assert!(matches!(out, Ok(Some(Msg::Test))));

        assert_eq!(buffer.len(), 0);
    }

    /// An incomplete header or payload must yield `Ok(None)` and leave the buffer untouched.
    #[test]
    fn test_json_codec_partial_frame() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut buffer = BytesMut::new();

        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");

        // Full header, truncated payload.
        buffer.truncate(TEST_FRAME_LEN - 1);
        assert!(matches!(codec.decode(&mut buffer), Ok(None)));
        assert_eq!(buffer.len(), TEST_FRAME_LEN - 1);

        // Truncated header.
        buffer.truncate(U32_WIDTH - 1);
        assert!(matches!(codec.decode(&mut buffer), Ok(None)));
        assert_eq!(buffer.len(), U32_WIDTH - 1);
    }

    /// A frame with invalid JSON errors out but is consumed, so the next frame still decodes.
    #[test]
    fn test_json_codec_invalid_frame() {
        let mut codec: JsonCodec<Msg, Msg> = JsonCodec::default();
        let mut buffer = BytesMut::new();

        let payload = b"not json";
        buffer.put_u32(u32::try_from(payload.len()).expect("tiny payload fits in u32"));
        buffer.put_slice(payload);
        codec
            .encode(Msg::Test, &mut buffer)
            .expect("Failed to encode");

        assert!(codec.decode(&mut buffer).is_err());
        assert_eq!(buffer.len(), TEST_FRAME_LEN);

        assert!(matches!(codec.decode(&mut buffer), Ok(Some(Msg::Test))));
        assert_eq!(buffer.len(), 0);
    }
}