1use crate::error::Error;
2use crate::kani;
3use crate::state::*;
4use std::collections::VecDeque;
5
6use std::sync::atomic::{AtomicU32, Ordering};
7use tokio::sync::Mutex;
8
9use std::sync::Arc;
10
11async fn apply_flags(client: Arc<kani::KanidmOrcaClient>, flags: &[Flag]) -> Result<(), Error> {
12 for flag in flags {
13 match flag {
14 Flag::DisableAllPersonsMFAPolicy => client.disable_mfa_requirement().await?,
15 Flag::ExtendPrivilegedAuthExpiry => client.extend_privilege_expiry().await?,
16 }
17 }
18 Ok(())
19}
20
21async fn preflight_person(
22 client: Arc<kani::KanidmOrcaClient>,
23 person: Person,
24) -> Result<(), Error> {
25 debug!(?person);
26
27 if client.person_exists(&person.username).await? {
28 return Ok(());
30 } else {
31 client
32 .person_create(&person.username, &person.display_name)
33 .await?;
34 }
35
36 match &person.credential {
37 Credential::Password { plain } => {
38 client
39 .person_set_primary_password_only(&person.username, plain)
40 .await?;
41 }
42 }
43
44 for role in &person.roles {
46 if let Some(need_groups) = role.requires_membership_to() {
47 for group_name in need_groups {
48 client
49 .group_add_members(group_name, &[person.username.as_str()])
50 .await?;
51 }
52 }
53 }
54
55 Ok(())
56}
57
58async fn preflight_group(client: Arc<kani::KanidmOrcaClient>, group: Group) -> Result<(), Error> {
59 if client.group_exists(&group.name.to_string()).await? {
60 } else {
62 client.group_create(&group.name.to_string()).await?;
63 }
64
65 let members = group.members.iter().map(|s| s.as_str()).collect::<Vec<_>>();
68
69 client
70 .group_set_members(&group.name.to_string(), members.as_slice())
71 .await?;
72
73 Ok(())
74}
75
76pub async fn preflight(state: State) -> Result<(), Error> {
77 let client = Arc::new(kani::KanidmOrcaClient::new(&state.profile).await?);
79
80 apply_flags(client.clone(), state.preflight_flags.as_slice()).await?;
82
83 let state_persons_len = state.persons.len();
84 let mut tasks = VecDeque::with_capacity(state_persons_len);
85
86 for person in state.persons.into_iter() {
88 let c = client.clone();
89 tasks.push_back(preflight_person(c, person))
93 }
94
95 let tasks = Arc::new(Mutex::new(tasks));
96 let counter = Arc::new(AtomicU32::new(0));
97 let par = std::thread::available_parallelism().unwrap();
98
99 let handles: Vec<_> = (0..par.into())
100 .map(|_| {
101 let tasks_q = tasks.clone();
102 let counter_c = counter.clone();
103 tokio::spawn(async move {
104 loop {
105 let maybe_task = async {
106 let mut guard = tasks_q.lock().await;
107 guard.pop_front()
108 }
109 .await;
110
111 if let Some(t) = maybe_task {
112 let _ = t.await;
113 let was = counter_c.fetch_add(1, Ordering::Relaxed);
114 if was % 1000 == 999 {
115 let order = was + 1;
116 eprint!("{}", order);
117 } else if was % 100 == 99 {
118 eprint!(".");
120 }
121 } else {
122 break;
124 }
125 }
126 })
127 })
128 .collect();
129
130 for handle in handles {
131 handle.await.map_err(|tokio_err| {
132 error!(?tokio_err, "Failed to join task");
133 Error::Tokio
134 })?;
135 }
136
137 eprintln!("done");
138
139 let counter = Arc::new(AtomicU32::new(0));
141 let mut tasks = Vec::with_capacity(state.groups.len());
142
143 for group in state.groups.into_iter() {
144 let c = client.clone();
145 tasks.push(preflight_group(c, group))
149 }
150
151 for task in tasks {
152 task.await?;
153 let was = counter.fetch_add(1, Ordering::Relaxed);
154 if was % 1000 == 999 {
155 let order = was + 1;
156 eprint!("{}", order);
157 } else if was % 100 == 99 {
158 eprint!(".");
160 }
161 }
162
163 eprintln!("done");
164
165 info!("Ready to 🛫");
168 Ok(())
169}