testkit_macros/
entry.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, ExprAssign, Token};
4
5use quote::{quote, quote_spanned, ToTokens};
6
/// Attribute keys accepted by the test macro.
///
/// Only a subset of the configuration can be tweaked for now; extend this
/// list as further knobs are needed.
const ALLOWED_ATTRIBUTES: &[&str] = &[
    "threads",
    "db_path",
    "maximum_request",
    "http_client_address_info",
    "role",
    "output_mode",
    "log_level",
    "ldap",
    "with_test_env",
];
20
/// Switches collected while parsing the macro attributes that influence the
/// shape of the generated test driver (as opposed to the server configuration).
#[derive(Default)]
struct Flags {
    /// When set, the test function is handed `&test_env` rather than
    /// `&test_env.rsclient`.
    target_wants_test_env: bool,
}
25
26fn parse_attributes(
27    args: &TokenStream,
28    input: &syn::ItemFn,
29) -> Result<(proc_macro2::TokenStream, Flags), syn::Error> {
30    let args: Punctuated<ExprAssign, syn::token::Comma> =
31        Punctuated::<ExprAssign, Token![,]>::parse_terminated.parse(args.clone())?;
32
33    let args_are_allowed = args.pairs().all(|p| {
34        ALLOWED_ATTRIBUTES.to_vec().contains(
35            &p.value()
36                .left
37                .span()
38                .source_text()
39                .unwrap_or_default()
40                .as_str(),
41        )
42    });
43
44    if !args_are_allowed {
45        let msg = "Invalid test config attribute. The following are allowed";
46        return Err(syn::Error::new_spanned(
47            input.sig.fn_token,
48            format!("{}: {}", msg, ALLOWED_ATTRIBUTES.join(", ")),
49        ));
50    }
51
52    let mut flags = Flags::default();
53    let mut field_modifications = quote! {};
54
55    args.pairs().for_each(|p| {
56        match p
57            .value()
58            .left
59            .span()
60            .source_text()
61            .unwrap_or_default()
62            .as_str()
63        {
64            "with_test_env" => {
65                flags.target_wants_test_env = true;
66            }
67            "ldap" => {
68                flags.target_wants_test_env = true;
69                field_modifications.extend(quote! {
70                ldapbindaddress: Some("on".to_string()),})
71            }
72            _ => {
73                let field_name = p.value().left.to_token_stream(); // here we can use to_token_stream as we know we're iterating over ExprAssigns
74                let field_value = p.value().right.to_token_stream();
75                // This is printing out struct members.
76                field_modifications.extend(quote! {
77                #field_name: #field_value,})
78            }
79        }
80    });
81
82    let ts = quote!(kanidmd_core::config::Configuration {
83        #field_modifications
84        ..kanidmd_core::config::Configuration::new_for_test()
85    });
86
87    Ok((ts, flags))
88}
89
90pub(crate) fn test(args: TokenStream, item: TokenStream) -> TokenStream {
91    // If any of the steps for this macro fail, we still want to expand to an item that is as close
92    // to the expected output as possible. This helps out IDEs such that completions and other
93    // related features keep working.
94    let input: syn::ItemFn = match syn::parse(item.clone()) {
95        Ok(it) => it,
96        Err(e) => return token_stream_with_error(item, e),
97    };
98
99    if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
100        let msg = "second test attribute is supplied";
101        return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
102    };
103
104    if input.sig.asyncness.is_none() {
105        let msg = "the `async` keyword is missing from the function declaration";
106        return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
107    }
108
109    // If type mismatch occurs, the current rustc points to the last statement.
110    let (last_stmt_start_span, _last_stmt_end_span) = {
111        let mut last_stmt = input
112            .block
113            .stmts
114            .last()
115            .map(ToTokens::into_token_stream)
116            .unwrap_or_default()
117            .into_iter();
118        // `Span` on stable Rust has a limitation that only points to the first
119        // token, not the whole tokens. We can work around this limitation by
120        // using the first/last span of the tokens like
121        // `syn::Error::new_spanned` does.
122        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
123        let end = last_stmt.last().map_or(start, |t| t.span());
124        (start, end)
125    };
126
127    // Setup the config filling the remaining fields with the default values
128    let (default_config_struct, flags) = match parse_attributes(&args, &input) {
129        Ok(dc) => dc,
130        Err(e) => return token_stream_with_error(args, e),
131    };
132
133    let rt = quote_spanned! {last_stmt_start_span=>
134        tokio::runtime::Builder::new_current_thread()
135    };
136
137    let header = quote! {
138        #[::core::prelude::v1::test]
139    };
140
141    let test_fn_args = if flags.target_wants_test_env {
142        quote! {
143            &test_env
144        }
145    } else {
146        quote! {
147            &test_env.rsclient
148        }
149    };
150
151    let test_fn = &input.sig.ident;
152    let test_driver = Ident::new(&format!("tk_{}", test_fn), input.sig.span());
153
154    // Effectively we are just injecting a real test function around this which we will
155    // call.
156    let result = quote! {
157        #input
158
159        #header
160        fn #test_driver() {
161            let body = async {
162                let mut test_env = kanidmd_testkit::setup_async_test(#default_config_struct).await;
163
164                #test_fn(#test_fn_args).await;
165                test_env.core_handle.shutdown().await;
166            };
167            #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
168            {
169                return #rt
170                    .enable_all()
171                    .build()
172                    .expect("Failed building the Runtime")
173                    .block_on(body);
174            }
175        }
176    };
177
178    result.into()
179}
180
181fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
182    tokens.extend(TokenStream::from(error.into_compile_error()));
183    tokens
184}