1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::{quote, quote_spanned, ToTokens};
4use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, ExprAssign, Token};
5
6fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
7 tokens.extend(TokenStream::from(error.into_compile_error()));
8 tokens
9}
10
/// Argument names accepted by the test attribute macros in this file.
///
/// `audit` toggles `Flags::audit`. Every other entry is written verbatim as a field
/// override of `crate::testkit::TestConfiguration`.
const ALLOWED_ATTRIBUTES: &[&'static str] = &["audit", "domain_level"];
12
/// Switches parsed from the test attribute arguments that change the generated
/// driver, as opposed to fields of the test configuration.
#[derive(Default)]
struct Flags {
    /// Set by the `audit` argument. When true, `idm_test` also passes the audit
    /// receiver (`&mut idms_audit`) to the test function.
    audit: bool,
}
17
18fn parse_attributes(
19 args: &TokenStream,
20 input: &syn::ItemFn,
21) -> Result<(proc_macro2::TokenStream, Flags), syn::Error> {
22 let args: Punctuated<ExprAssign, syn::token::Comma> =
23 Punctuated::<ExprAssign, Token![,]>::parse_terminated.parse(args.clone())?;
24
25 let args_are_allowed = args.pairs().all(|p| {
26 ALLOWED_ATTRIBUTES.to_vec().contains(
27 &p.value()
28 .left
29 .span()
30 .source_text()
31 .unwrap_or_default()
32 .as_str(),
33 )
34 });
35
36 if !args_are_allowed {
37 let msg = "Invalid test config attribute. The following are allowed";
38 return Err(syn::Error::new_spanned(
39 input.sig.fn_token,
40 format!("{}: {}", msg, ALLOWED_ATTRIBUTES.join(", ")),
41 ));
42 }
43
44 let mut flags = Flags::default();
45 let mut field_modifications = quote! {};
46
47 args.pairs().for_each(|p| {
48 match p
49 .value()
50 .left
51 .span()
52 .source_text()
53 .unwrap_or_default()
54 .as_str()
55 {
56 "audit" => flags.audit = true,
57 _ => {
58 let field_name = p.value().left.to_token_stream(); let field_value = p.value().right.to_token_stream();
60 field_modifications.extend(quote! {
61 #field_name: #field_value,})
62 }
63 }
64 });
65
66 let ts = quote!(crate::testkit::TestConfiguration {
67 #field_modifications
68 ..crate::testkit::TestConfiguration::default()
69 });
70
71 Ok((ts, flags))
72}
73
74pub(crate) fn qs_test(args: TokenStream, item: TokenStream) -> TokenStream {
75 let input: syn::ItemFn = match syn::parse(item.clone()) {
76 Ok(it) => it,
77 Err(e) => return token_stream_with_error(item, e),
78 };
79
80 if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
81 let msg = "second test attribute is supplied";
82 return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
83 };
84
85 if input.sig.asyncness.is_none() {
86 let msg = "the `async` keyword is missing from the function declaration";
87 return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
88 }
89
90 let (last_stmt_start_span, _last_stmt_end_span) = {
92 let mut last_stmt = input
93 .block
94 .stmts
95 .last()
96 .map(ToTokens::into_token_stream)
97 .unwrap_or_default()
98 .into_iter();
99 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
104 let end = last_stmt.last().map_or(start, |t| t.span());
105 (start, end)
106 };
107
108 let (default_config_struct, _flags) = match parse_attributes(&args, &input) {
110 Ok(dc) => dc,
111 Err(e) => return token_stream_with_error(args, e),
112 };
113
114 let rt = quote_spanned! {last_stmt_start_span=>
115 tokio::runtime::Builder::new_current_thread()
116 };
117
118 let header = quote! {
119 #[::core::prelude::v1::test]
120 };
121
122 let test_fn = &input.sig.ident;
123 let test_driver = Ident::new(&format!("qs_{}", test_fn), input.sig.span());
124
125 let result = quote! {
129 #input
130
131 #header
132 fn #test_driver() {
133 let body = async {
134 let test_config = #default_config_struct;
135
136 #[cfg(feature = "dhat-heap")]
137 let _profiler = dhat::Profiler::new_heap();
138
139 let test_server = crate::testkit::setup_test(test_config).await;
140
141 #test_fn(&test_server).await;
142
143 #[cfg(feature = "dhat-heap")]
144 drop(_profiler);
145
146 assert!(test_server.clear_cache().await.is_ok());
149 let verifications = test_server.verify().await;
151 trace!("Verification result: {:?}", verifications);
152 assert_eq!(verifications.len(),0);
153 };
154 #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
155 {
156 return #rt
157 .enable_all()
158 .build()
159 .expect("Failed building the Runtime")
160 .block_on(body);
161 }
162 }
163 };
164
165 result.into()
166}
167
168pub(crate) fn qs_pair_test(args: &TokenStream, item: TokenStream) -> TokenStream {
169 let input: syn::ItemFn = match syn::parse(item.clone()) {
170 Ok(it) => it,
171 Err(e) => return token_stream_with_error(item, e),
172 };
173
174 if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
175 let msg = "second test attribute is supplied";
176 return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
177 };
178
179 if input.sig.asyncness.is_none() {
180 let msg = "the `async` keyword is missing from the function declaration";
181 return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
182 }
183
184 let (last_stmt_start_span, _last_stmt_end_span) = {
186 let mut last_stmt = input
187 .block
188 .stmts
189 .last()
190 .map(ToTokens::into_token_stream)
191 .unwrap_or_default()
192 .into_iter();
193 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
198 let end = last_stmt.last().map_or(start, |t| t.span());
199 (start, end)
200 };
201
202 let rt = quote_spanned! {last_stmt_start_span=>
203 tokio::runtime::Builder::new_current_thread()
204 };
205
206 let header = quote! {
207 #[::core::prelude::v1::test]
208 };
209
210 let (default_config_struct, _flags) = match parse_attributes(args, &input) {
212 Ok(dc) => dc,
213 Err(e) => return token_stream_with_error(args.clone(), e),
214 };
215
216 let test_fn = &input.sig.ident;
217 let test_driver = Ident::new(&format!("qs_{}", test_fn), input.sig.span());
218
219 let result = quote! {
223 #input
224
225 #header
226 fn #test_driver() {
227 let body = async {
228 let test_config = #default_config_struct;
229
230 #[cfg(feature = "dhat-heap")]
231 let _profiler = dhat::Profiler::new_heap();
232
233 let (server_a, server_b) = crate::testkit::setup_pair_test(test_config).await;
234
235 #test_fn(&server_a, &server_b).await;
236
237 #[cfg(feature = "dhat-heap")]
238 drop(_profiler);
239
240 assert!(server_a.clear_cache().await.is_ok());
242 assert!(server_b.clear_cache().await.is_ok());
243 let verifications_a = server_a.verify().await;
245 let verifications_b = server_b.verify().await;
246 trace!("Verification result: {:?}, {:?}", verifications_a, verifications_b);
247 assert!(verifications_a.len() + verifications_b.len() == 0);
248 };
249 #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
250 {
251 return #rt
252 .enable_all()
253 .build()
254 .expect("Failed building the Runtime")
255 .block_on(body);
256 }
257 }
258 };
259
260 result.into()
261}
262
263pub(crate) fn idm_test(args: &TokenStream, item: TokenStream) -> TokenStream {
264 let input: syn::ItemFn = match syn::parse(item.clone()) {
265 Ok(it) => it,
266 Err(e) => return token_stream_with_error(item, e),
267 };
268
269 if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
270 let msg = "second test attribute is supplied";
271 return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
272 };
273
274 if input.sig.asyncness.is_none() {
275 let msg = "the `async` keyword is missing from the function declaration";
276 return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
277 }
278
279 let (last_stmt_start_span, _last_stmt_end_span) = {
281 let mut last_stmt = input
282 .block
283 .stmts
284 .last()
285 .map(ToTokens::into_token_stream)
286 .unwrap_or_default()
287 .into_iter();
288 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
293 let end = last_stmt.last().map_or(start, |t| t.span());
294 (start, end)
295 };
296
297 let (default_config_struct, flags) = match parse_attributes(args, &input) {
299 Ok(dc) => dc,
300 Err(e) => return token_stream_with_error(args.clone(), e),
301 };
302
303 let rt = quote_spanned! {last_stmt_start_span=>
304 tokio::runtime::Builder::new_current_thread()
305 };
306
307 let header = quote! {
308 #[::core::prelude::v1::test]
309 };
310
311 let test_fn = &input.sig.ident;
312 let test_driver = Ident::new(&format!("idm_{}", test_fn), input.sig.span());
313
314 let test_fn_args = if flags.audit {
315 quote! {
316 &test_server, &mut idms_delayed, &mut idms_audit
317 }
318 } else {
319 quote! {
320 &test_server, &mut idms_delayed
321 }
322 };
323
324 let result = quote! {
328 #input
329
330 #header
331 fn #test_driver() {
332 let body = async {
333 let test_config = #default_config_struct;
334
335 #[cfg(feature = "dhat-heap")]
336 let _profiler = dhat::Profiler::new_heap();
337
338 let (test_server, mut idms_delayed, mut idms_audit) = crate::testkit::setup_idm_test(test_config).await;
339
340 #test_fn(#test_fn_args).await;
341
342 #[cfg(feature = "dhat-heap")]
343 drop(_profiler);
344
345 let mut idm_read_txn = test_server.proxy_read().await.unwrap();
349 let verifications = idm_read_txn.qs_read.verify();
350 trace!("Verification result: {:?}", verifications);
351 assert_eq!(verifications.len(),0);
352
353 idms_delayed.check_is_empty_or_panic();
354 idms_audit.check_is_empty_or_panic();
355 };
356 #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
357 {
358 return #rt
359 .enable_all()
360 .build()
361 .expect("Failed building the Runtime")
362 .block_on(body);
363 }
364 }
365 };
366
367 result.into()
368}