// kanidmd_core/repl/codec.rs
1use bytes::{Buf, BufMut, BytesMut};
2use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4use std::io;
5use tokio_util::codec::{Decoder, Encoder};
6
/// Capacity (1 MiB) of the fresh buffer that replaces an oversized, empty
/// codec buffer.
///
/// Note: the name is misspelled ("MIMIMUM"). It is `pub`, so it is left as-is
/// to avoid breaking users.
pub const CODEC_MIMIMUM_BYTESMUT_ALLOCATION: usize = 1 << 20;

/// Capacity threshold (8 MiB). A codec buffer that is empty and has at least
/// this much capacity is swapped for a fresh one of
/// [`CODEC_MIMIMUM_BYTESMUT_ALLOCATION`] bytes, releasing memory retained
/// after a large frame.
pub const CODEC_BYTESMUT_ALLOCATION_LIMIT: usize = 8 * CODEC_MIMIMUM_BYTESMUT_ALLOCATION;
12
/// Requests sent by a replication consumer to a supplier.
///
/// Values are serialized as JSON by the codecs in this module (serde's default
/// externally-tagged enum representation), so variant names are part of the
/// wire format.
#[derive(Serialize, Deserialize, Debug)]
pub enum ConsumerRequest {
    /// Liveness check; presumably answered with `SupplierResponse::Pong`.
    Ping,
    /// Request for incremental changes, carrying a `ReplRuvRange`
    /// (NOTE(review): presumably the consumer's current RUV range — confirm).
    Incremental(ReplRuvRange),
    /// Request for a full refresh; carries no payload.
    Refresh,
}
19
/// Responses sent by a replication supplier back to a consumer.
///
/// Values are serialized as JSON by the codecs in this module (serde's default
/// externally-tagged enum representation), so variant names are part of the
/// wire format.
#[derive(Serialize, Deserialize, Debug)]
pub enum SupplierResponse {
    /// Reply to a liveness check (presumably to `ConsumerRequest::Ping`).
    Pong,
    /// Incremental replication payload.
    Incremental(ReplIncrementalContext),
    /// Full refresh payload.
    Refresh(ReplRefreshContext),
}
26
27#[derive(Default)]
28pub struct ConsumerCodec {
29 max_frame_bytes: usize,
30}
31
32impl ConsumerCodec {
33 pub fn new(max_frame_bytes: usize) -> Self {
34 ConsumerCodec { max_frame_bytes }
35 }
36}
37
38impl Decoder for ConsumerCodec {
39 type Error = io::Error;
40 type Item = SupplierResponse;
41
42 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
43 decode_length_checked_json(self.max_frame_bytes, src)
44 }
45}
46
47impl Encoder<ConsumerRequest> for ConsumerCodec {
48 type Error = io::Error;
49
50 fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
51 encode_length_checked_json(msg, dst)
52 }
53}
54
55#[derive(Default)]
56pub struct SupplierCodec {
57 max_frame_bytes: usize,
58}
59
60impl SupplierCodec {
61 pub fn new(max_frame_bytes: usize) -> Self {
62 SupplierCodec { max_frame_bytes }
63 }
64}
65
66impl Decoder for SupplierCodec {
67 type Error = io::Error;
68 type Item = ConsumerRequest;
69
70 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
71 decode_length_checked_json(self.max_frame_bytes, src)
72 }
73}
74
75impl Encoder<SupplierResponse> for SupplierCodec {
76 type Error = io::Error;
77
78 fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
79 encode_length_checked_json(msg, dst)
80 }
81}
82
83fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
84 if dst.is_empty() && dst.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
87 dst.clear();
88 let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
89 std::mem::swap(&mut buf, dst);
90 }
91
92 let mut work = dst.split_off(dst.len());
94
95 let zero_len = u64::MIN.to_be_bytes();
97 work.extend_from_slice(&zero_len);
98
99 let json_buf = work.split_off(zero_len.len());
105
106 let mut json_writer = json_buf.writer();
107
108 serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
109 error!(?err, "consumer encoding error");
110 io::Error::other("JSON encode error")
111 })?;
112
113 let json_buf = json_writer.into_inner();
114
115 let final_len = json_buf.len() as u64;
116 let final_len_bytes = final_len.to_be_bytes();
117
118 if final_len_bytes.len() != work.len() {
119 error!("consumer buffer size error");
120 return Err(io::Error::other("buffer length error"));
121 }
122
123 work.copy_from_slice(&final_len_bytes);
124
125 work.unsplit(json_buf);
127
128 dst.unsplit(work);
129
130 Ok(())
131}
132
133fn decode_length_checked_json<T: DeserializeOwned>(
134 max_frame_bytes: usize,
135 src: &mut BytesMut,
136) -> Result<Option<T>, io::Error> {
137 trace!(capacity = ?src.capacity());
138
139 if src.len() < 8 {
140 trace!("Insufficient bytes for length header.");
142 return Ok(None);
143 }
144
145 let (src_len_bytes, json_bytes) = src.split_at(8);
146 let mut len_be_bytes = [0; 8];
147
148 assert_eq!(len_be_bytes.len(), src_len_bytes.len());
149 len_be_bytes.copy_from_slice(src_len_bytes);
150 let req_len = u64::from_be_bytes(len_be_bytes);
151
152 if req_len == 0 {
153 error!("request has size 0");
154 return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
155 }
156
157 if req_len > max_frame_bytes as u64 {
158 error!(
159 "requested decode frame too large {} > {}",
160 req_len, max_frame_bytes
161 );
162 return Err(io::Error::new(
163 io::ErrorKind::OutOfMemory,
164 "request too large",
165 ));
166 }
167
168 if (json_bytes.len() as u64) < req_len {
169 trace!(
170 "Insufficient bytes for json, need: {} have: {}",
171 req_len,
172 src.len()
173 );
174 return Ok(None);
175 }
176
177 debug_assert!(req_len as usize <= json_bytes.len());
179 let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
180
181 let res = serde_json::from_slice(json_bytes)
183 .map(|msg| Some(msg))
184 .map_err(|err| {
185 error!(?err, "received invalid input");
186 io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
187 });
188
189 if src.len() as u64 == req_len {
191 src.clear();
192 if src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
193 let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
194 std::mem::swap(&mut buf, src);
195 }
196 } else {
197 src.advance((8 + req_len) as usize);
198 };
199
200 res
201}
202
#[cfg(test)]
mod tests {
    use std::io;

    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    /// Exercises partial frames, header validation errors, the wire format
    /// and round trips in both directions (including two frames in one buffer).
    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let mut buf = BytesMut::with_capacity(32);

        // An empty buffer is not yet a frame.
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Half of a length header is still not enough to decide anything.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // A complete header that declares a zero-length body is rejected.
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 8);
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Err(err) if err.kind() == io::ErrorKind::InvalidInput
        ));

        // A header declaring more than max_frame_bytes is rejected up front.
        buf.clear();
        buf.extend_from_slice(&34_u64.to_be_bytes());
        assert_eq!(buf.len(), 8);
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Err(err) if err.kind() == io::ErrorKind::OutOfMemory
        ));

        // A valid header with only part of its body waits for more bytes.
        buf.clear();
        buf.extend_from_slice(&20_u64.to_be_bytes());
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 12);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Consumer -> supplier round trip, also pinning the wire format:
        // big-endian u64 body length, then the JSON body.
        buf.clear();
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert_eq!(&buf[..8], &6_u64.to_be_bytes());
        assert_eq!(&buf[8..], b"\"Ping\"");
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());

        // Supplier -> consumer round trip.
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());

        // Two frames in one buffer decode one at a time.
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());

        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(!buf.is_empty());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));

        assert!(buf.is_empty());
    }
}