kanidmd_core/repl/codec.rs
1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::io;
4use tokio_util::codec::{Decoder, Encoder};
5
6use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
7
/// Capacity, in bytes (1 MiB), given to the fresh `BytesMut` that replaces an
/// oversized, empty codec buffer.
///
/// The "MIMIMUM" spelling is part of the public name and is kept for compatibility.
pub const CODEC_MIMIMUM_BYTESMUT_ALLOCATION: usize = 1 << 20;

/// Capacity threshold, in bytes (8 MiB). An empty codec buffer whose capacity is at
/// or above this value is swapped for a new one of
/// `CODEC_MIMIMUM_BYTESMUT_ALLOCATION` bytes, releasing the larger allocation.
pub const CODEC_BYTESMUT_ALLOCATION_LIMIT: usize = 8 << 20;
13
14#[derive(Serialize, Deserialize, Debug)]
15pub enum ConsumerRequest {
16 Ping,
17 Incremental(ReplRuvRange),
18 Refresh,
19}
20
21#[derive(Serialize, Deserialize, Debug)]
22pub enum SupplierResponse {
23 Pong,
24 Incremental(ReplIncrementalContext),
25 Refresh(ReplRefreshContext),
26}
27
/// Codec for the consumer end of a replication connection.
///
/// It encodes `ConsumerRequest`s and decodes `SupplierResponse`s. Both directions use
/// frames made of an 8-byte big-endian payload length followed by JSON.
///
/// Note: the derived `Default` sets `max_frame_bytes` to 0. Zero-length frames are also
/// rejected, so a defaulted codec cannot decode any frame. Use `ConsumerCodec::new`.
#[derive(Default)]
pub struct ConsumerCodec {
    /// Largest payload length (bytes, excluding the 8-byte header) the decoder accepts.
    max_frame_bytes: usize,
}

impl ConsumerCodec {
    /// Create a codec whose decoder rejects frames with payloads larger than
    /// `max_frame_bytes` bytes.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
38
39impl Decoder for ConsumerCodec {
40 type Error = io::Error;
41 type Item = SupplierResponse;
42
43 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
44 decode_length_checked_json(self.max_frame_bytes, src)
45 }
46}
47
48impl Encoder<ConsumerRequest> for ConsumerCodec {
49 type Error = io::Error;
50
51 fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
52 encode_length_checked_json(msg, dst)
53 }
54}
55
/// Codec for the supplier end of a replication connection.
///
/// It decodes `ConsumerRequest`s and encodes `SupplierResponse`s. Both directions use
/// frames made of an 8-byte big-endian payload length followed by JSON.
///
/// Note: the derived `Default` sets `max_frame_bytes` to 0. Zero-length frames are also
/// rejected, so a defaulted codec cannot decode any frame. Use `SupplierCodec::new`.
#[derive(Default)]
pub struct SupplierCodec {
    /// Largest payload length (bytes, excluding the 8-byte header) the decoder accepts.
    max_frame_bytes: usize,
}

impl SupplierCodec {
    /// Create a codec whose decoder rejects frames with payloads larger than
    /// `max_frame_bytes` bytes.
    pub fn new(max_frame_bytes: usize) -> Self {
        Self { max_frame_bytes }
    }
}
66
67impl Decoder for SupplierCodec {
68 type Error = io::Error;
69 type Item = ConsumerRequest;
70
71 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
72 decode_length_checked_json(self.max_frame_bytes, src)
73 }
74}
75
76impl Encoder<SupplierResponse> for SupplierCodec {
77 type Error = io::Error;
78
79 fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
80 encode_length_checked_json(msg, dst)
81 }
82}
83
84fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
85 if dst.is_empty() && dst.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
88 dst.clear();
89 let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
90 std::mem::swap(&mut buf, dst);
91 }
92
93 let mut work = dst.split_off(dst.len());
95
96 let zero_len = u64::MIN.to_be_bytes();
98 work.extend_from_slice(&zero_len);
99
100 let json_buf = work.split_off(zero_len.len());
106
107 let mut json_writer = json_buf.writer();
108
109 serde_json::to_writer(&mut json_writer, &msg).map_err(|err| {
110 error!(?err, "consumer encoding error");
111 io::Error::other("JSON encode error")
112 })?;
113
114 let json_buf = json_writer.into_inner();
115
116 let final_len = json_buf.len() as u64;
117 let final_len_bytes = final_len.to_be_bytes();
118
119 if final_len_bytes.len() != work.len() {
120 error!("consumer buffer size error");
121 return Err(io::Error::other("buffer length error"));
122 }
123
124 work.copy_from_slice(&final_len_bytes);
125
126 work.unsplit(json_buf);
128
129 dst.unsplit(work);
130
131 Ok(())
132}
133
134fn decode_length_checked_json<T: DeserializeOwned>(
135 max_frame_bytes: usize,
136 src: &mut BytesMut,
137) -> Result<Option<T>, io::Error> {
138 trace!(capacity = ?src.capacity());
139
140 if src.len() < 8 {
141 trace!("Insufficient bytes for length header.");
143 return Ok(None);
144 }
145
146 let (src_len_bytes, json_bytes) = src.split_at(8);
147 let mut len_be_bytes = [0; 8];
148
149 assert_eq!(len_be_bytes.len(), src_len_bytes.len());
150 len_be_bytes.copy_from_slice(src_len_bytes);
151 let req_len = u64::from_be_bytes(len_be_bytes);
152
153 if req_len == 0 {
154 error!("request has size 0");
155 return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
156 }
157
158 if req_len > max_frame_bytes as u64 {
159 error!(
160 "requested decode frame too large {} > {}",
161 req_len, max_frame_bytes
162 );
163 return Err(io::Error::new(
164 io::ErrorKind::OutOfMemory,
165 "request too large",
166 ));
167 }
168
169 if (json_bytes.len() as u64) < req_len {
170 trace!(
171 "Insufficient bytes for json, need: {} have: {}",
172 req_len,
173 src.len()
174 );
175 return Ok(None);
176 }
177
178 debug_assert!(req_len as usize <= json_bytes.len());
180 let (json_bytes, _remainder) = json_bytes.split_at(req_len as usize);
181
182 let res = serde_json::from_slice(json_bytes)
184 .map(|msg| Some(msg))
185 .map_err(|err| {
186 error!(?err, "received invalid input");
187 io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
188 });
189
190 if src.len() as u64 == req_len {
192 src.clear();
193 if src.capacity() >= CODEC_BYTESMUT_ALLOCATION_LIMIT {
194 let mut buf = BytesMut::with_capacity(CODEC_MIMIMUM_BYTESMUT_ALLOCATION);
195 std::mem::swap(&mut buf, src);
196 }
197 } else {
198 src.advance((8 + req_len) as usize);
199 };
200
201 res
202}
203
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);

        let mut buf = BytesMut::with_capacity(32);

        // Empty buffer: nothing to decode yet.
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Half of a length header is still not enough.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);

        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // A full header declaring a zero-length payload is rejected.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // A header declaring more than max_frame_bytes is rejected without the payload.
        buf.clear();
        let len_bytes = (34_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);

        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // A valid header with only part of its payload waits for more bytes.
        buf.clear();
        let len_bytes = (20_u64).to_be_bytes();
        buf.extend_from_slice(&len_bytes);
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 12);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Round trip in both directions, one frame at a time.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());

        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());

        // Two frames queued back to back are decoded one per call.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());

        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(!buf.is_empty());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));

        assert!(buf.is_empty());
    }

    /// A frame whose header and payload arrive in separate reads is only yielded once
    /// its last byte is present, and decoding it leaves the buffer empty.
    #[test]
    fn test_repl_codec_partial_frame() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let mut frame = BytesMut::new();
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut frame)
            .is_ok());
        // 8-byte header + the 6-byte JSON string "Ping" (quotes included).
        assert_eq!(frame.len(), 8 + 6);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(&frame[..8]);
        assert!(matches!(supplier_codec.decode(&mut buf), Ok(None)));

        buf.extend_from_slice(&frame[8..10]);
        assert!(matches!(supplier_codec.decode(&mut buf), Ok(None)));

        buf.extend_from_slice(&frame[10..]);
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());
    }

    /// A frame with invalid JSON is reported as an error but still consumed, so the
    /// following frame decodes normally.
    #[test]
    fn test_repl_codec_invalid_json_consumes_frame() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);
        let mut supplier_codec = SupplierCodec::new(32);

        let garbage = b"{not json";
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&(garbage.len() as u64).to_be_bytes());
        buf.extend_from_slice(garbage);
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());

        assert!(supplier_codec.decode(&mut buf).is_err());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(buf.is_empty());
    }

    /// A payload exactly `max_frame_bytes` long is accepted; one byte over is rejected.
    #[test]
    fn test_repl_codec_frame_limit_is_inclusive() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);

        let mut buf = BytesMut::new();
        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        // The "Ping" payload is 6 bytes.
        let mut exact = buf.clone();

        assert!(matches!(
            SupplierCodec::new(6).decode(&mut exact),
            Ok(Some(ConsumerRequest::Ping))
        ));
        assert!(SupplierCodec::new(5).decode(&mut buf).is_err());
    }
}