testkit_macros/
entry.rs
1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use syn::{parse::Parser, punctuated::Punctuated, spanned::Spanned, ExprAssign, Token};
4
5use quote::{quote, quote_spanned, ToTokens};
6
/// Keys accepted as `key = value` arguments of the test attribute.
///
/// The list is also joined (in this order) into the "invalid test config
/// attribute" diagnostic, so keep it readable.
const ALLOWED_ATTRIBUTES: &[&str] = &[
    // Forwarded as field initialisers onto `kanidmd_core::config::Configuration`.
    "threads",
    "db_path",
    "maximum_request",
    "http_client_address_info",
    "role",
    "output_mode",
    "log_level",
    // Interpreted specially by `parse_attributes`.
    "ldap",
    "with_test_env",
];
20
/// Switches derived from the attribute arguments that change how the
/// generated test driver invokes the user's test function.
#[derive(Default)]
struct Flags {
    /// Set by the `with_test_env` or `ldap` keys. When `true` the test function
    /// is called with `&test_env`; otherwise with `&test_env.rsclient`.
    target_wants_test_env: bool,
}
25
26fn parse_attributes(
27 args: &TokenStream,
28 input: &syn::ItemFn,
29) -> Result<(proc_macro2::TokenStream, Flags), syn::Error> {
30 let args: Punctuated<ExprAssign, syn::token::Comma> =
31 Punctuated::<ExprAssign, Token![,]>::parse_terminated.parse(args.clone())?;
32
33 let args_are_allowed = args.pairs().all(|p| {
34 ALLOWED_ATTRIBUTES.to_vec().contains(
35 &p.value()
36 .left
37 .span()
38 .source_text()
39 .unwrap_or_default()
40 .as_str(),
41 )
42 });
43
44 if !args_are_allowed {
45 let msg = "Invalid test config attribute. The following are allowed";
46 return Err(syn::Error::new_spanned(
47 input.sig.fn_token,
48 format!("{}: {}", msg, ALLOWED_ATTRIBUTES.join(", ")),
49 ));
50 }
51
52 let mut flags = Flags::default();
53 let mut field_modifications = quote! {};
54
55 args.pairs().for_each(|p| {
56 match p
57 .value()
58 .left
59 .span()
60 .source_text()
61 .unwrap_or_default()
62 .as_str()
63 {
64 "with_test_env" => {
65 flags.target_wants_test_env = true;
66 }
67 "ldap" => {
68 flags.target_wants_test_env = true;
69 field_modifications.extend(quote! {
70 ldapbindaddress: Some("on".to_string()),})
71 }
72 _ => {
73 let field_name = p.value().left.to_token_stream(); let field_value = p.value().right.to_token_stream();
75 field_modifications.extend(quote! {
77 #field_name: #field_value,})
78 }
79 }
80 });
81
82 let ts = quote!(kanidmd_core::config::Configuration {
83 #field_modifications
84 ..kanidmd_core::config::Configuration::new_for_test()
85 });
86
87 Ok((ts, flags))
88}
89
90pub(crate) fn test(args: TokenStream, item: TokenStream) -> TokenStream {
91 let input: syn::ItemFn = match syn::parse(item.clone()) {
95 Ok(it) => it,
96 Err(e) => return token_stream_with_error(item, e),
97 };
98
99 if let Some(attr) = input.attrs.iter().find(|attr| attr.path().is_ident("test")) {
100 let msg = "second test attribute is supplied";
101 return token_stream_with_error(item, syn::Error::new_spanned(attr, msg));
102 };
103
104 if input.sig.asyncness.is_none() {
105 let msg = "the `async` keyword is missing from the function declaration";
106 return token_stream_with_error(item, syn::Error::new_spanned(input.sig.fn_token, msg));
107 }
108
109 let (last_stmt_start_span, _last_stmt_end_span) = {
111 let mut last_stmt = input
112 .block
113 .stmts
114 .last()
115 .map(ToTokens::into_token_stream)
116 .unwrap_or_default()
117 .into_iter();
118 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
123 let end = last_stmt.last().map_or(start, |t| t.span());
124 (start, end)
125 };
126
127 let (default_config_struct, flags) = match parse_attributes(&args, &input) {
129 Ok(dc) => dc,
130 Err(e) => return token_stream_with_error(args, e),
131 };
132
133 let rt = quote_spanned! {last_stmt_start_span=>
134 tokio::runtime::Builder::new_current_thread()
135 };
136
137 let header = quote! {
138 #[::core::prelude::v1::test]
139 };
140
141 let test_fn_args = if flags.target_wants_test_env {
142 quote! {
143 &test_env
144 }
145 } else {
146 quote! {
147 &test_env.rsclient
148 }
149 };
150
151 let test_fn = &input.sig.ident;
152 let test_driver = Ident::new(&format!("tk_{}", test_fn), input.sig.span());
153
154 let result = quote! {
157 #input
158
159 #header
160 fn #test_driver() {
161 let body = async {
162 let mut test_env = kanidmd_testkit::setup_async_test(#default_config_struct).await;
163
164 #test_fn(#test_fn_args).await;
165 test_env.core_handle.shutdown().await;
166 };
167 #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
168 {
169 return #rt
170 .enable_all()
171 .build()
172 .expect("Failed building the Runtime")
173 .block_on(body);
174 }
175 }
176 };
177
178 result.into()
179}
180
181fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
182 tokens.extend(TokenStream::from(error.into_compile_error()));
183 tokens
184}