kanidmd_lib_macros/
entry.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::{quote, quote_spanned, ToTokens};
4use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, ExprAssign, Token};
5
6fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
7    tokens.extend(TokenStream::from(error.into_compile_error()));
8    tokens
9}
10
/// Names that may appear on the left-hand side of a `name = value` argument passed to the
/// test attribute macros. `parse_attributes` rejects any argument whose name is not listed.
const ALLOWED_ATTRIBUTES: &[&str] = &["audit", "domain_level"];
12
/// Switches collected by `parse_attributes` that alter the generated test driver itself
/// rather than being written into the test configuration struct.
#[derive(Default)]
struct Flags {
    /// Set when an `audit = ...` argument is present. `idm_test` reads this to decide
    /// whether `&mut idms_audit` is passed to the test function as a third argument.
    audit: bool,
}
17
18fn parse_attributes(
19    args: &TokenStream,
20    input: &syn::ItemFn,
21) -> Result<(proc_macro2::TokenStream, Flags), syn::Error> {
22    let args: Punctuated<ExprAssign, syn::token::Comma> =
23        Punctuated::<ExprAssign, Token![,]>::parse_terminated.parse(args.clone())?;
24
25    let args_are_allowed = args.pairs().all(|p| {
26        ALLOWED_ATTRIBUTES.to_vec().contains(
27            &p.value()
28                .left
29                .span()
30                .source_text()
31                .unwrap_or_default()
32                .as_str(),
33        )
34    });
35
36    if !args_are_allowed {
37        let msg = "Invalid test config attribute. The following are allowed";
38        return Err(syn::Error::new_spanned(
39            input.sig.fn_token,
40            format!("{}: {}", msg, ALLOWED_ATTRIBUTES.join(", ")),
41        ));
42    }
43
44    let mut flags = Flags::default();
45    let mut field_modifications = quote! {};
46
47    args.pairs().for_each(|p| {
48        match p
49            .value()
50            .left
51            .span()
52            .source_text()
53            .unwrap_or_default()
54            .as_str()
55        {
56            "audit" => flags.audit = true,
57            _ => {
58                let field_name = p.value().left.to_token_stream(); // here we can use to_token_stream as we know we're iterating over ExprAssigns
59                let field_value = p.value().right.to_token_stream();
60                field_modifications.extend(quote! {
61                #field_name: #field_value,})
62            }
63        }
64    });
65
66    let ts = quote!(crate::testkit::TestConfiguration {
67        #field_modifications
68        ..crate::testkit::TestConfiguration::default()
69    });
70
71    Ok((ts, flags))
72}
73
74pub(crate) fn qs_test(args: TokenStream, item: TokenStream) -> TokenStream {
75    let input: syn::ItemFn = match syn::parse(item.clone()) {
76        Ok(it) => it,
77        Err(e) => return token_stream_with_error(item, e),
78    };
79
80    if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
81        let msg = "second test attribute is supplied";
82        return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
83    };
84
85    if input.sig.asyncness.is_none() {
86        let msg = "the `async` keyword is missing from the function declaration";
87        return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
88    }
89
90    // If type mismatch occurs, the current rustc points to the last statement.
91    let (last_stmt_start_span, _last_stmt_end_span) = {
92        let mut last_stmt = input
93            .block
94            .stmts
95            .last()
96            .map(ToTokens::into_token_stream)
97            .unwrap_or_default()
98            .into_iter();
99        // `Span` on stable Rust has a limitation that only points to the first
100        // token, not the whole tokens. We can work around this limitation by
101        // using the first/last span of the tokens like
102        // `syn::Error::new_spanned` does.
103        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
104        let end = last_stmt.last().map_or(start, |t| t.span());
105        (start, end)
106    };
107
108    // Setup the config filling the remaining fields with the default values
109    let (default_config_struct, _flags) = match parse_attributes(&args, &input) {
110        Ok(dc) => dc,
111        Err(e) => return token_stream_with_error(args, e),
112    };
113
114    let rt = quote_spanned! {last_stmt_start_span=>
115        tokio::runtime::Builder::new_current_thread()
116    };
117
118    let header = quote! {
119        #[::core::prelude::v1::test]
120    };
121
122    let test_fn = &input.sig.ident;
123    let test_driver = Ident::new(&format!("qs_{}", test_fn), input.sig.span());
124
125    // Effectively we are just injecting a real test function around this which we will
126    // call.
127
128    let result = quote! {
129        #input
130
131        #header
132        fn #test_driver() {
133            let body = async {
134                let test_config = #default_config_struct;
135
136                #[cfg(feature = "dhat-heap")]
137                let _profiler = dhat::Profiler::new_heap();
138
139                let test_server = crate::testkit::setup_test(test_config).await;
140
141                #test_fn(&test_server).await;
142
143                #[cfg(feature = "dhat-heap")]
144                drop(_profiler);
145
146                // Any needed teardown?
147                // Clear the cache before we verify.
148                assert!(test_server.clear_cache().await.is_ok());
149                // Make sure there are no errors.
150                let verifications = test_server.verify().await;
151                trace!("Verification result: {:?}", verifications);
152                assert_eq!(verifications.len(),0);
153            };
154            #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
155            {
156                return #rt
157                    .enable_all()
158                    .build()
159                    .expect("Failed building the Runtime")
160                    .block_on(body);
161            }
162        }
163    };
164
165    result.into()
166}
167
168pub(crate) fn qs_pair_test(args: &TokenStream, item: TokenStream) -> TokenStream {
169    let input: syn::ItemFn = match syn::parse(item.clone()) {
170        Ok(it) => it,
171        Err(e) => return token_stream_with_error(item, e),
172    };
173
174    if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
175        let msg = "second test attribute is supplied";
176        return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
177    };
178
179    if input.sig.asyncness.is_none() {
180        let msg = "the `async` keyword is missing from the function declaration";
181        return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
182    }
183
184    // If type mismatch occurs, the current rustc points to the last statement.
185    let (last_stmt_start_span, _last_stmt_end_span) = {
186        let mut last_stmt = input
187            .block
188            .stmts
189            .last()
190            .map(ToTokens::into_token_stream)
191            .unwrap_or_default()
192            .into_iter();
193        // `Span` on stable Rust has a limitation that only points to the first
194        // token, not the whole tokens. We can work around this limitation by
195        // using the first/last span of the tokens like
196        // `syn::Error::new_spanned` does.
197        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
198        let end = last_stmt.last().map_or(start, |t| t.span());
199        (start, end)
200    };
201
202    let rt = quote_spanned! {last_stmt_start_span=>
203        tokio::runtime::Builder::new_current_thread()
204    };
205
206    let header = quote! {
207        #[::core::prelude::v1::test]
208    };
209
210    // Setup the config filling the remaining fields with the default values
211    let (default_config_struct, _flags) = match parse_attributes(args, &input) {
212        Ok(dc) => dc,
213        Err(e) => return token_stream_with_error(args.clone(), e),
214    };
215
216    let test_fn = &input.sig.ident;
217    let test_driver = Ident::new(&format!("qs_{}", test_fn), input.sig.span());
218
219    // Effectively we are just injecting a real test function around this which we will
220    // call.
221
222    let result = quote! {
223        #input
224
225        #header
226        fn #test_driver() {
227            let body = async {
228                let test_config = #default_config_struct;
229
230                #[cfg(feature = "dhat-heap")]
231                let _profiler = dhat::Profiler::new_heap();
232
233                let (server_a, server_b) = crate::testkit::setup_pair_test(test_config).await;
234
235                #test_fn(&server_a, &server_b).await;
236
237                #[cfg(feature = "dhat-heap")]
238                drop(_profiler);
239
240                // Any needed teardown?
241                assert!(server_a.clear_cache().await.is_ok());
242                assert!(server_b.clear_cache().await.is_ok());
243                // Make sure there are no errors.
244                let verifications_a = server_a.verify().await;
245                let verifications_b = server_b.verify().await;
246                trace!("Verification result: {:?}, {:?}", verifications_a, verifications_b);
247                assert!(verifications_a.len() + verifications_b.len() == 0);
248            };
249            #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
250            {
251                return #rt
252                    .enable_all()
253                    .build()
254                    .expect("Failed building the Runtime")
255                    .block_on(body);
256            }
257        }
258    };
259
260    result.into()
261}
262
263pub(crate) fn idm_test(args: &TokenStream, item: TokenStream) -> TokenStream {
264    let input: syn::ItemFn = match syn::parse(item.clone()) {
265        Ok(it) => it,
266        Err(e) => return token_stream_with_error(item, e),
267    };
268
269    if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
270        let msg = "second test attribute is supplied";
271        return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
272    };
273
274    if input.sig.asyncness.is_none() {
275        let msg = "the `async` keyword is missing from the function declaration";
276        return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
277    }
278
279    // If type mismatch occurs, the current rustc points to the last statement.
280    let (last_stmt_start_span, _last_stmt_end_span) = {
281        let mut last_stmt = input
282            .block
283            .stmts
284            .last()
285            .map(ToTokens::into_token_stream)
286            .unwrap_or_default()
287            .into_iter();
288        // `Span` on stable Rust has a limitation that only points to the first
289        // token, not the whole tokens. We can work around this limitation by
290        // using the first/last span of the tokens like
291        // `syn::Error::new_spanned` does.
292        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
293        let end = last_stmt.last().map_or(start, |t| t.span());
294        (start, end)
295    };
296
297    // Setup the config filling the remaining fields with the default values
298    let (default_config_struct, flags) = match parse_attributes(args, &input) {
299        Ok(dc) => dc,
300        Err(e) => return token_stream_with_error(args.clone(), e),
301    };
302
303    let rt = quote_spanned! {last_stmt_start_span=>
304        tokio::runtime::Builder::new_current_thread()
305    };
306
307    let header = quote! {
308        #[::core::prelude::v1::test]
309    };
310
311    let test_fn = &input.sig.ident;
312    let test_driver = Ident::new(&format!("idm_{}", test_fn), input.sig.span());
313
314    let test_fn_args = if flags.audit {
315        quote! {
316            &test_server, &mut idms_delayed, &mut idms_audit
317        }
318    } else {
319        quote! {
320            &test_server, &mut idms_delayed
321        }
322    };
323
324    // Effectively we are just injecting a real test function around this which we will
325    // call.
326
327    let result = quote! {
328        #input
329
330        #header
331        fn #test_driver() {
332            let body = async {
333                let test_config = #default_config_struct;
334
335                #[cfg(feature = "dhat-heap")]
336                let _profiler = dhat::Profiler::new_heap();
337
338                let (test_server, mut idms_delayed, mut idms_audit)  = crate::testkit::setup_idm_test(test_config).await;
339
340                #test_fn(#test_fn_args).await;
341
342                #[cfg(feature = "dhat-heap")]
343                drop(_profiler);
344
345                // Any needed teardown?
346                // assert!(test_server.clear_cache().await.is_ok());
347                // Make sure there are no errors.
348                let mut idm_read_txn = test_server.proxy_read().await.unwrap();
349                let verifications = idm_read_txn.qs_read.verify();
350                trace!("Verification result: {:?}", verifications);
351                assert_eq!(verifications.len(),0);
352
353                idms_delayed.check_is_empty_or_panic();
354                idms_audit.check_is_empty_or_panic();
355            };
356            #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
357            {
358                return #rt
359                    .enable_all()
360                    .build()
361                    .expect("Failed building the Runtime")
362                    .block_on(body);
363            }
364        }
365    };
366
367    result.into()
368}