kanidmd_lib/server/
batch_modify.rs

1use super::{ChangeFlag, QueryServerWriteTransaction};
2use crate::prelude::*;
3use crate::server::Plugins;
4use std::collections::BTreeMap;
5
6pub type ModSetValid = BTreeMap<Uuid, ModifyList<ModifyValid>>;
7
8pub struct BatchModifyEvent {
9    pub ident: Identity,
10    pub modset: ModSetValid,
11}
12
13impl QueryServerWriteTransaction<'_> {
14    /// This function behaves different to modify. Modify applies the same
15    /// modification operation en-mass to 1 -> N entries. This takes a set of modifications
16    /// that define a precise entry to apply a change to and only modifies that.
17    ///
18    /// modify is for all entries matching this condition, do this change.
19    ///
20    /// batch_modify is for entry X apply mod A, for entry Y apply mod B etc. It allows you
21    /// to do per-entry mods.
22    ///
23    /// The drawback is you need to know ahead of time what uuids you are affecting. This
24    /// has parallels to scim, so it's not a significant issue.
25    ///
26    /// Otherwise, we follow the same pattern here as modify, and inside the transform
27    /// the same modlists are used.
28    #[instrument(level = "debug", skip_all)]
29    pub fn batch_modify(&mut self, me: &BatchModifyEvent) -> Result<(), OperationError> {
30        // ⚠️  =========
31        // Effectively this is the same as modify but instead of apply modlist
32        // we do it by uuid.
33
34        // Get the candidates.
35        // Modify applies a modlist to a filter, so we need to internal search
36        // then apply.
37        if !me.ident.is_internal() {
38            security_info!(name = %me.ident, "batch modify initiator");
39        }
40
41        // Validate input.
42
43        // Is the modlist non zero?
44        if me.modset.is_empty() {
45            request_error!("empty modify request");
46            return Err(OperationError::EmptyRequest);
47        }
48
49        let filter_or = me
50            .modset
51            .keys()
52            .copied()
53            .map(|u| f_eq(Attribute::Uuid, PartialValue::Uuid(u)))
54            .collect();
55
56        let filter = filter_all!(f_or(filter_or))
57            .validate(self.get_schema())
58            .map_err(OperationError::SchemaViolation)?;
59
60        // This also checks access controls due to use of the impersonation.
61        let pre_candidates = self
62            .impersonate_search_valid(filter.clone(), filter.clone(), &me.ident)
63            .map_err(|e| {
64                admin_error!("error in pre-candidate selection {:?}", e);
65                e
66            })?;
67
68        if pre_candidates.is_empty() {
69            if me.ident.is_internal() {
70                trace!("no candidates match filter ... continuing {:?}", filter);
71                return Ok(());
72            } else {
73                request_error!("no candidates match modset request, failure {:?}", filter);
74                return Err(OperationError::NoMatchingEntries);
75            }
76        };
77
78        if pre_candidates.len() != me.modset.len() {
79            error!("Inconsistent modify, some uuids were not found in request.");
80            return Err(OperationError::MissingEntries);
81        }
82
83        trace!("pre_candidates -> {:?}", pre_candidates);
84        trace!("modset -> {:?}", me.modset);
85
86        // Are we allowed to make the changes we want to?
87        // modify_allow_operation
88        let access = self.get_accesscontrols();
89
90        let op_allow = access
91            .batch_modify_allow_operation(me, &pre_candidates)
92            .map_err(|e| {
93                admin_error!("Unable to check batch modify access {:?}", e);
94                e
95            })?;
96        if !op_allow {
97            return Err(OperationError::AccessDenied);
98        }
99
100        // Clone a set of writeables.
101        // Apply the modlist -> Remember, we have a set of origs
102        // and the new modified ents.
103        // =========
104        // The primary difference to modify is here - notice we do per-uuid mods.
105        let mut candidates = pre_candidates
106            .iter()
107            .map(|er| {
108                let u = er.get_uuid();
109                let mut ent_mut = er
110                    .as_ref()
111                    .clone()
112                    .invalidate(self.cid.clone(), &self.trim_cid);
113
114                me.modset
115                    .get(&u)
116                    .ok_or_else(|| {
117                        error!("No entry for uuid {} was found, aborting", u);
118                        OperationError::NoMatchingEntries
119                    })
120                    .and_then(|modlist| {
121                        ent_mut
122                            .apply_modlist(modlist)
123                            // Return if success
124                            .map(|()| ent_mut)
125                            // Error log otherwise.
126                            .inspect_err(|_e| {
127                                error!("Modification failed for {}", u);
128                            })
129                    })
130            })
131            .collect::<Result<Vec<EntryInvalidCommitted>, _>>()?;
132
133        // Did any of the candidates now become masked?
134        if std::iter::zip(
135            pre_candidates
136                .iter()
137                .map(|e| e.mask_recycled_ts().is_none()),
138            candidates.iter().map(|e| e.mask_recycled_ts().is_none()),
139        )
140        .any(|(a, b)| a != b)
141        {
142            admin_warn!("Refusing to apply modifications that are attempting to bypass replication state machine.");
143            return Err(OperationError::AccessDenied);
144        }
145
146        // Pre mod plugins
147        // We should probably supply the pre-post cands here.
148        Plugins::run_pre_batch_modify(self, &pre_candidates, &mut candidates, me).map_err(|e| {
149            admin_error!("Pre-Modify operation failed (plugin), {:?}", e);
150            e
151        })?;
152
153        let norm_cand = candidates
154            .into_iter()
155            .map(|entry| {
156                entry
157                    .validate(&self.schema)
158                    .map_err(|e| {
159                        admin_error!("Schema Violation in validation of modify_pre_apply {:?}", e);
160                        OperationError::SchemaViolation(e)
161                    })
162                    .map(|entry| entry.seal(&self.schema))
163            })
164            .collect::<Result<Vec<EntrySealedCommitted>, _>>()?;
165
166        // Backend Modify
167        self.be_txn
168            .modify(&self.cid, &pre_candidates, &norm_cand)
169            .map_err(|e| {
170                admin_error!("Modify operation failed (backend), {:?}", e);
171                e
172            })?;
173
174        // Post Plugins
175        //
176        // memberOf actually wants the pre cand list and the norm_cand list to see what
177        // changed. Could be optimised, but this is correct still ...
178        Plugins::run_post_batch_modify(self, &pre_candidates, &norm_cand, me).map_err(|e| {
179            admin_error!("Post-Modify operation failed (plugin), {:?}", e);
180            e
181        })?;
182
183        // We have finished all plugs and now have a successful operation - flag if
184        // schema or acp requires reload. Remember, this is a modify, so we need to check
185        // pre and post cands.
186        if !self.changed_flags.contains(ChangeFlag::SCHEMA)
187            && norm_cand
188                .iter()
189                .chain(pre_candidates.iter().map(|e| e.as_ref()))
190                .any(|e| {
191                    e.attribute_equality(Attribute::Class, &EntryClass::ClassType.into())
192                        || e.attribute_equality(Attribute::Class, &EntryClass::AttributeType.into())
193                })
194        {
195            self.changed_flags.insert(ChangeFlag::SCHEMA)
196        }
197
198        if !self.changed_flags.contains(ChangeFlag::ACP)
199            && norm_cand
200                .iter()
201                .chain(pre_candidates.iter().map(|e| e.as_ref()))
202                .any(|e| {
203                    e.attribute_equality(Attribute::Class, &EntryClass::AccessControlProfile.into())
204                })
205        {
206            self.changed_flags.insert(ChangeFlag::ACP)
207        }
208
209        if !self.changed_flags.contains(ChangeFlag::APPLICATION)
210            && norm_cand
211                .iter()
212                .chain(pre_candidates.iter().map(|e| e.as_ref()))
213                .any(|e| e.attribute_equality(Attribute::Class, &EntryClass::Application.into()))
214        {
215            self.changed_flags.insert(ChangeFlag::APPLICATION)
216        }
217
218        if !self.changed_flags.contains(ChangeFlag::OAUTH2)
219            && norm_cand
220                .iter()
221                .chain(pre_candidates.iter().map(|e| e.as_ref()))
222                .any(|e| {
223                    e.attribute_equality(Attribute::Class, &EntryClass::OAuth2ResourceServer.into())
224                })
225        {
226            self.changed_flags.insert(ChangeFlag::OAUTH2)
227        }
228
229        if !self.changed_flags.contains(ChangeFlag::OAUTH2_CLIENT)
230            && norm_cand
231                .iter()
232                .chain(pre_candidates.iter().map(|e| e.as_ref()))
233                .any(|e| e.attribute_equality(Attribute::Class, &EntryClass::OAuth2Client.into()))
234        {
235            self.changed_flags.insert(ChangeFlag::OAUTH2_CLIENT)
236        }
237
238        if !self.changed_flags.contains(ChangeFlag::FEATURE)
239            && norm_cand
240                .iter()
241                .chain(pre_candidates.iter().map(|e| e.as_ref()))
242                .any(|e| e.attribute_equality(Attribute::Class, &EntryClass::Feature.into()))
243        {
244            self.changed_flags.insert(ChangeFlag::FEATURE)
245        }
246
247        if !self.changed_flags.contains(ChangeFlag::DOMAIN)
248            && norm_cand
249                .iter()
250                .chain(pre_candidates.iter().map(|e| e.as_ref()))
251                .any(|e| e.attribute_equality(Attribute::Uuid, &PVUUID_DOMAIN_INFO))
252        {
253            self.changed_flags.insert(ChangeFlag::DOMAIN)
254        }
255
256        if !self.changed_flags.contains(ChangeFlag::SYSTEM_CONFIG)
257            && norm_cand
258                .iter()
259                .chain(pre_candidates.iter().map(|e| e.as_ref()))
260                .any(|e| e.attribute_equality(Attribute::Uuid, &PVUUID_SYSTEM_CONFIG))
261        {
262            self.changed_flags.insert(ChangeFlag::SYSTEM_CONFIG)
263        }
264
265        if !self.changed_flags.contains(ChangeFlag::SYNC_AGREEMENT)
266            && norm_cand
267                .iter()
268                .chain(pre_candidates.iter().map(|e| e.as_ref()))
269                .any(|e| e.attribute_equality(Attribute::Class, &EntryClass::SyncAccount.into()))
270        {
271            self.changed_flags.insert(ChangeFlag::SYNC_AGREEMENT)
272        }
273
274        if !self.changed_flags.contains(ChangeFlag::KEY_MATERIAL)
275            && norm_cand
276                .iter()
277                .chain(pre_candidates.iter().map(|e| e.as_ref()))
278                .any(|e| {
279                    e.attribute_equality(Attribute::Class, &EntryClass::KeyProvider.into())
280                        || e.attribute_equality(Attribute::Class, &EntryClass::KeyObject.into())
281                })
282        {
283            self.changed_flags.insert(ChangeFlag::KEY_MATERIAL)
284        }
285
286        self.changed_uuid.extend(
287            norm_cand
288                .iter()
289                .map(|e| e.get_uuid())
290                .chain(pre_candidates.iter().map(|e| e.get_uuid())),
291        );
292
293        trace!(
294            changed = ?self.changed_flags.iter_names().collect::<Vec<_>>(),
295        );
296
297        // return
298        if me.ident.is_internal() {
299            trace!("Modify operation success");
300        } else {
301            admin_info!("Modify operation success");
302        }
303        Ok(())
304    }
305
306    pub fn internal_batch_modify(
307        &mut self,
308        mods_iter: impl Iterator<Item = (Uuid, ModifyList<ModifyInvalid>)>,
309    ) -> Result<(), OperationError> {
310        let modset = mods_iter
311            .map(|(u, ml)| {
312                ml.validate(self.get_schema())
313                    .map(|modlist| (u, modlist))
314                    .map_err(OperationError::SchemaViolation)
315            })
316            .collect::<Result<ModSetValid, _>>()?;
317        let bme = BatchModifyEvent {
318            ident: Identity::from_internal(),
319            modset,
320        };
321        self.batch_modify(&bme)
322    }
323}
324
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    #[qs_test]
    async fn test_batch_modify_basic(server: &QueryServer) {
        let mut server_txn = server.write(duration_from_epoch_now()).await.unwrap();
        // Setup entries.
        let uuid_a = Uuid::new_v4();
        let uuid_b = Uuid::new_v4();
        server_txn
            .internal_create(vec![
                entry_init!(
                    (Attribute::Class, EntryClass::Object.to_value()),
                    (Attribute::Uuid, Value::Uuid(uuid_a))
                ),
                entry_init!(
                    (Attribute::Class, EntryClass::Object.to_value()),
                    (Attribute::Uuid, Value::Uuid(uuid_b))
                ),
            ])
            .expect("Failed to create test entries");

        // Do a batch mod - each entry receives its own, different description.
        server_txn
            .internal_batch_modify(
                [
                    (
                        uuid_a,
                        ModifyList::new_append(Attribute::Description, Value::Utf8("a".into())),
                    ),
                    (
                        uuid_b,
                        ModifyList::new_append(Attribute::Description, Value::Utf8("b".into())),
                    ),
                ]
                .into_iter(),
            )
            .expect("Batch modify failed");

        // Now check them
        let ent_a = server_txn
            .internal_search_uuid(uuid_a)
            .expect("Failed to get entry.");
        let ent_b = server_txn
            .internal_search_uuid(uuid_b)
            .expect("Failed to get entry.");

        assert_eq!(ent_a.get_ava_single_utf8(Attribute::Description), Some("a"));
        assert_eq!(ent_b.get_ava_single_utf8(Attribute::Description), Some("b"));
    }

    #[qs_test]
    async fn test_batch_modify_empty_request(server: &QueryServer) {
        let mut server_txn = server.write(duration_from_epoch_now()).await.unwrap();

        // An empty modset is rejected outright.
        assert_eq!(
            server_txn.internal_batch_modify(std::iter::empty()),
            Err(OperationError::EmptyRequest)
        );
    }

    #[qs_test]
    async fn test_batch_modify_missing_entries(server: &QueryServer) {
        let mut server_txn = server.write(duration_from_epoch_now()).await.unwrap();
        let uuid_a = Uuid::new_v4();
        let uuid_missing = Uuid::new_v4();
        server_txn
            .internal_create(vec![entry_init!(
                (Attribute::Class, EntryClass::Object.to_value()),
                (Attribute::Uuid, Value::Uuid(uuid_a))
            )])
            .expect("Failed to create test entry");

        // A batch where only some of the uuids exist is rejected as a whole ...
        assert_eq!(
            server_txn.internal_batch_modify(
                [
                    (
                        uuid_a,
                        ModifyList::new_append(Attribute::Description, Value::Utf8("a".into())),
                    ),
                    (
                        uuid_missing,
                        ModifyList::new_append(Attribute::Description, Value::Utf8("x".into())),
                    ),
                ]
                .into_iter(),
            ),
            Err(OperationError::MissingEntries)
        );

        // ... and none of its modifications were applied.
        let ent_a = server_txn
            .internal_search_uuid(uuid_a)
            .expect("Failed to get entry.");
        assert_eq!(ent_a.get_ava_single_utf8(Attribute::Description), None);

        // An internal batch that matches no entries at all is a no-op rather than an error.
        assert_eq!(
            server_txn.internal_batch_modify(
                [(
                    uuid_missing,
                    ModifyList::new_append(Attribute::Description, Value::Utf8("x".into())),
                )]
                .into_iter(),
            ),
            Ok(())
        );
    }
}