kanidmd_core/repl/
codec.rs
1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::io;
4use tokio_util::codec::{Decoder, Encoder};
5
6use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
7
8#[derive(Serialize, Deserialize, Debug)]
9pub enum ConsumerRequest {
10 Ping,
11 Incremental(ReplRuvRange),
12 Refresh,
13}
14
15#[derive(Serialize, Deserialize, Debug)]
16pub enum SupplierResponse {
17 Pong,
18 Incremental(ReplIncrementalContext),
19 Refresh(ReplRefreshContext),
20}
21
22#[derive(Default)]
23pub struct ConsumerCodec {
24 max_frame_bytes: usize,
25}
26
27impl ConsumerCodec {
28 pub fn new(max_frame_bytes: usize) -> Self {
29 ConsumerCodec { max_frame_bytes }
30 }
31}
32
33impl Decoder for ConsumerCodec {
34 type Error = io::Error;
35 type Item = SupplierResponse;
36
37 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
38 decode_length_checked_json(self.max_frame_bytes, src)
39 }
40}
41
42impl Encoder<ConsumerRequest> for ConsumerCodec {
43 type Error = io::Error;
44
45 fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
46 encode_length_checked_json(msg, dst)
47 }
48}
49
50#[derive(Default)]
51pub struct SupplierCodec {
52 max_frame_bytes: usize,
53}
54
55impl SupplierCodec {
56 pub fn new(max_frame_bytes: usize) -> Self {
57 SupplierCodec { max_frame_bytes }
58 }
59}
60
61impl Decoder for SupplierCodec {
62 type Error = io::Error;
63 type Item = ConsumerRequest;
64
65 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66 decode_length_checked_json(self.max_frame_bytes, src)
67 }
68}
69
70impl Encoder<SupplierResponse> for SupplierCodec {
71 type Error = io::Error;
72
73 fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
74 encode_length_checked_json(msg, dst)
75 }
76}
77
78fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
79 let mut work = dst.split_off(dst.len());
81
82 let zero_len = u64::MIN.to_be_bytes();
84 work.extend_from_slice(&zero_len);
85
86 let json_buf = work.split_off(zero_len.len());
92
93 let mut json_writer = json_buf.writer();
94
95 serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
96 error!(?err, "consumer encoding error");
97 io::Error::new(io::ErrorKind::Other, "JSON encode error")
98 })?;
99
100 let json_buf = json_writer.into_inner();
101
102 let final_len = json_buf.len() as u64;
103 let final_len_bytes = final_len.to_be_bytes();
104
105 if final_len_bytes.len() != work.len() {
106 error!("consumer buffer size error");
107 return Err(io::Error::new(io::ErrorKind::Other, "buffer length error"));
108 }
109
110 work.copy_from_slice(&final_len_bytes);
111
112 work.unsplit(json_buf);
114
115 dst.unsplit(work);
116
117 Ok(())
118}
119
120fn decode_length_checked_json<T: DeserializeOwned>(
121 max_frame_bytes: usize,
122 src: &mut BytesMut,
123) -> Result<Option<T>, io::Error> {
124 trace!(capacity = ?src.capacity());
125
126 if src.len() < 8 {
127 trace!("Insufficient bytes for length header.");
129 return Ok(None);
130 }
131
132 let (src_len_bytes, json_bytes) = src.split_at(8);
133 let mut len_be_bytes = [0; 8];
134
135 assert_eq!(len_be_bytes.len(), src_len_bytes.len());
136 len_be_bytes.copy_from_slice(src_len_bytes);
137 let req_len = u64::from_be_bytes(len_be_bytes);
138
139 if req_len == 0 {
140 error!("request has size 0");
141 return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
142 }
143
144 if req_len > max_frame_bytes as u64 {
145 error!(
146 "requested decode frame too large {} > {}",
147 req_len, max_frame_bytes
148 );
149 return Err(io::Error::new(
150 io::ErrorKind::OutOfMemory,
151 "request too large",
152 ));
153 }
154
155 if (json_bytes.len() as u64) < req_len {
156 trace!(
157 "Insufficient bytes for json, need: {} have: {}",
158 req_len,
159 src.len()
160 );
161 return Ok(None);
162 }
163
164 debug_assert!(req_len as usize <= json_bytes.len());
166 let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
167
168 let res = serde_json::from_slice(json_bytes)
170 .map(|msg| Some(msg))
171 .map_err(|err| {
172 error!(?err, "received invalid input");
173 io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
174 });
175
176 if src.len() as u64 == req_len {
178 src.clear();
179 } else {
180 src.advance((8 + req_len) as usize);
181 };
182
183 res
184}
185
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    /// Header validation plus single- and multi-frame round trips.
    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);

        let mut buf = BytesMut::with_capacity(32);

        // Empty buffer: need more data.
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Half a header: need more data.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);

        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Complete header declaring a zero-length body: rejected.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Declared length above the 32 byte limit: rejected.
        buf.clear();
        let len_bytes = (34_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);

        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Valid length but body not fully arrived: need more data.
        buf.clear();
        let len_bytes = (20_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 12);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Round trips in both directions.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());

        // Two frames back to back: decoded one at a time.
        buf.clear();

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());

        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(!buf.is_empty());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));

        assert!(buf.is_empty());
    }

    /// A frame that arrives in pieces is not consumed until it is complete.
    #[test]
    fn test_repl_codec_partial_frame() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let mut encoded = BytesMut::new();
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut encoded)
            .is_ok());
        // Header plus at least one body byte, so a 9 byte prefix is partial.
        assert!(encoded.len() > 9);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&encoded[..9]);
        assert!(matches!(supplier_codec.decode(&mut buf), Ok(None)));
        assert_eq!(buf.len(), 9);

        buf.extend_from_slice(&encoded[9..]);
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());
    }

    /// An invalid JSON body errors, but its frame is consumed and the next
    /// frame is still decodable.
    #[test]
    fn test_repl_codec_invalid_json_consumed() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&(3_u64).to_be_bytes());
        buf.extend_from_slice(b"xyz");
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());

        assert!(consumer_codec.decode(&mut buf).is_err());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());
    }
}