kanidm_unix_resolver/
db.rs

1use crate::idprovider::interface::{GroupToken, Id, UserToken};
2use async_trait::async_trait;
3use kanidm_hsm_crypto::structures::{LoadableHmacS256Key, LoadableStorageKey};
4use libc::umask;
5use rusqlite::{Connection, OptionalExtension};
6use serde::{de::DeserializeOwned, Serialize};
7use std::convert::TryFrom;
8use std::fmt;
9use tokio::sync::{Mutex, MutexGuard};
10use uuid::Uuid;
11
/// Key in `db_version_t` under which the main schema version is recorded.
const DBV_MAIN: &str = "main";

/// SQLite's `cache_size` pragma is given here in *pages*. With the default page size of
/// 4096 bytes, 32 MiB works out to 32 MiB / 4 KiB = 8192 pages.
const CACHE_SIZE: usize = (32 * 1024 * 1024) / 4096;
16
/// Something that can open a write transaction over its backing storage.
#[async_trait]
pub trait Cache {
    /// The transaction handle produced by [`Cache::write`]; it may borrow from the
    /// implementor for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Begin a write transaction. There is no error channel: the transaction is returned
    /// directly.
    async fn write<'db>(&'db self) -> Self::Txn<'db>;
}
25
/// Something that can open a write transaction for key-store (HSM key) access.
#[async_trait]
pub trait KeyStore {
    /// The transaction handle produced by [`KeyStore::write_keystore`]; it may borrow from
    /// the implementor for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Begin a key-store write transaction.
    async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
}
34
/// Errors returned by cache transaction operations on [`DbTxn`] / [`KeyStoreTxn`].
#[derive(Debug)]
pub enum CacheError {
    /// A cryptographic operation failed.
    // NOTE(review): not constructed in the code visible in this file — confirm the producer.
    Cryptography,
    /// A value could not be encoded to (or, for group members, decoded from) JSON.
    SerdeJson,
    /// A numeric conversion failed, e.g. an expiry that does not fit in `i64`/`u64`.
    Parse,
    /// An SQLite operation failed; the underlying error has already been logged.
    Sqlite,
    /// A lookup expected to match at most one row matched several.
    TooManyResults,
    /// The transaction was used in an invalid state (e.g. committed twice).
    TransactionInvalidState,
    /// A TPM operation failed.
    // NOTE(review): not constructed in the code visible in this file — confirm the producer.
    Tpm,
}
45
/// The resolver's SQLite-backed cache database.
///
/// The single connection lives behind an async mutex; each [`DbTxn`] holds the mutex
/// guard, so only one transaction can be open at a time.
pub struct Db {
    conn: Mutex<Connection>,
}
49
/// An open SQLite transaction on the cache database.
///
/// Owns the connection lock for its whole lifetime.
pub struct DbTxn<'a> {
    conn: MutexGuard<'a, Connection>,
    // Set by `commit`. NOTE(review): presumably consulted by a `Drop` impl elsewhere to
    // decide whether to roll back — confirm.
    committed: bool,
}
54
/// A mutable borrow of an open [`DbTxn`], used for tagged HSM key storage.
pub struct KeyStoreTxn<'a, 'b> {
    db: &'b mut DbTxn<'a>,
}
58
59impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
60    fn from(db: &'b mut DbTxn<'a>) -> Self {
61        Self { db }
62    }
63}
64
#[derive(Debug)]
/// Errors coming back from the `Db` struct (for example from `Db::new`).
pub enum DbError {
    /// Opening or configuring the SQLite connection failed; the cause has been logged.
    Sqlite,
    /// A TPM-related failure.
    // NOTE(review): not constructed in the code visible in this file — confirm the producer.
    Tpm,
}
71
72impl Db {
73    pub fn new(path: &str) -> Result<Self, DbError> {
74        let before = unsafe { umask(0o0027) };
75        let conn = Connection::open(path).map_err(|e| {
76            error!(err = ?e, "rusqulite error");
77            DbError::Sqlite
78        })?;
79        let _ = unsafe { umask(before) };
80
81        // Setup WAL/COW mode.
82        conn.pragma_update(None, "journal_mode", "WAL")
83            .map_err(|error| {
84                error!(
85                    "sqlite journal_mode=WAL error: {:?} db_path={:?}",
86                    error, path
87                );
88                DbError::Sqlite
89            })?;
90
91        // synchronous=normal is safe for WAL
92        conn.pragma_update(None, "synchronous", "NORMAL")
93            .map_err(|error| {
94                error!(
95                    "sqlite synchronous=NORMAL error: {:?} db_path={:?}",
96                    error, path
97                );
98                DbError::Sqlite
99            })?;
100
101        conn.pragma_update(None, "cache_size", CACHE_SIZE)
102            .map_err(|error| {
103                error!(
104                    "sqlite cache_size={} error: {:?} db_path={:?}",
105                    CACHE_SIZE, error, path
106                );
107                DbError::Sqlite
108            })?;
109
110        conn.set_prepared_statement_cache_capacity(32);
111
112        Ok(Db {
113            conn: Mutex::new(conn),
114        })
115    }
116}
117
118#[async_trait]
119impl Cache for Db {
120    type Txn<'db> = DbTxn<'db>;
121
122    #[allow(clippy::expect_used)]
123    async fn write<'db>(&'db self) -> Self::Txn<'db> {
124        let conn = self.conn.lock().await;
125        DbTxn::new(conn)
126    }
127}
128
129impl fmt::Debug for Db {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        write!(f, "Db {{}}")
132    }
133}
134
135impl<'a> DbTxn<'a> {
136    fn new(conn: MutexGuard<'a, Connection>) -> Self {
137        // Start the transaction
138        // trace!("Starting db WR txn ...");
139        #[allow(clippy::expect_used)]
140        conn.execute("BEGIN TRANSACTION", [])
141            .expect("Unable to begin transaction!");
142        DbTxn {
143            committed: false,
144            conn,
145        }
146    }
147
148    /// This handles an error coming back from an sqlite event and dumps more information from it
149    fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
150        error!(
151            "sqlite {} error: {:?} db_path={:?}",
152            msg,
153            error,
154            &self.conn.path()
155        );
156        CacheError::Sqlite
157    }
158
159    /// This handles an error coming back from an sqlite transaction and dumps a load of information from it
160    fn sqlite_transaction_error(
161        &self,
162        error: &rusqlite::Error,
163        _stmt: &rusqlite::Statement,
164    ) -> CacheError {
165        error!(
166            "sqlite transaction error={:?} db_path={:?}",
167            error,
168            &self.conn.path(),
169        );
170        // TODO: one day figure out if there's an easy way to dump the transaction without the token...
171        CacheError::Sqlite
172    }
173
174    fn get_db_version(&self, key: &str) -> i64 {
175        self.conn
176            .query_row(
177                "SELECT version FROM db_version_t WHERE id = :id",
178                &[(":id", key)],
179                |row| row.get(0),
180            )
181            .unwrap_or({
182                // The value is missing, default to 0.
183                0
184            })
185    }
186
187    fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
188        self.conn
189            .execute(
190                "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
191                named_params! {
192                    ":id": &key,
193                    ":dbv": v,
194                },
195            )
196            .map(|_| ())
197            .map_err(|e| self.sqlite_error("set db_version_t", &e))
198    }
199
200    fn get_account_data_name(
201        &mut self,
202        account_id: &str,
203    ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
204        let mut stmt = self.conn
205            .prepare(
206        "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
207            )
208            .map_err(|e| {
209                self.sqlite_error("select prepare", &e)
210            })?;
211
212        // Makes tuple (token, expiry)
213        let data_iter = stmt
214            .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
215            .map_err(|e| self.sqlite_error("query_map failure", &e))?;
216        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
217            .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
218            .collect();
219        data
220    }
221
222    fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
223        let mut stmt = self
224            .conn
225            .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
226            .map_err(|e| self.sqlite_error("select prepare", &e))?;
227
228        // Makes tuple (token, expiry)
229        let data_iter = stmt
230            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
231            .map_err(|e| self.sqlite_error("query_map", &e))?;
232        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
233            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
234            .collect();
235        data
236    }
237
238    fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
239        let mut stmt = self.conn
240            .prepare(
241                "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
242            )
243            .map_err(|e| {
244                self.sqlite_error("select prepare", &e)
245            })?;
246
247        // Makes tuple (token, expiry)
248        let data_iter = stmt
249            .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
250            .map_err(|e| self.sqlite_error("query_map", &e))?;
251        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
252            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
253            .collect();
254        data
255    }
256
257    fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
258        let mut stmt = self
259            .conn
260            .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
261            .map_err(|e| self.sqlite_error("select prepare", &e))?;
262
263        // Makes tuple (token, expiry)
264        let data_iter = stmt
265            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
266            .map_err(|e| self.sqlite_error("query_map", &e))?;
267        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
268            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
269            .collect();
270        data
271    }
272}
273
274impl KeyStoreTxn<'_, '_> {
275    pub fn get_tagged_hsm_key<K: DeserializeOwned>(
276        &mut self,
277        tag: &str,
278    ) -> Result<Option<K>, CacheError> {
279        self.db.get_tagged_hsm_key(tag)
280    }
281
282    pub fn insert_tagged_hsm_key<K: Serialize>(
283        &mut self,
284        tag: &str,
285        key: &K,
286    ) -> Result<(), CacheError> {
287        self.db.insert_tagged_hsm_key(tag, key)
288    }
289
290    pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
291        self.db.delete_tagged_hsm_key(tag)
292    }
293}
294
295impl DbTxn<'_> {
296    fn get_tagged_hsm_key<K: DeserializeOwned>(
297        &mut self,
298        tag: &str,
299    ) -> Result<Option<K>, CacheError> {
300        let mut stmt = self
301            .conn
302            .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
303            .map_err(|e| self.sqlite_error("select prepare", &e))?;
304
305        let data: Option<Vec<u8>> = stmt
306            .query_row(
307                named_params! {
308                    ":key": tag
309                },
310                |row| row.get(0),
311            )
312            .optional()
313            .map_err(|e| self.sqlite_error("query_row", &e))?;
314
315        match data {
316            Some(d) => Ok(serde_json::from_slice(d.as_slice())
317                .map_err(|e| {
318                    error!("json error -> {:?}", e);
319                })
320                .ok()),
321            None => Ok(None),
322        }
323    }
324
325    fn insert_tagged_hsm_key<K: Serialize>(
326        &mut self,
327        tag: &str,
328        key: &K,
329    ) -> Result<(), CacheError> {
330        let data = serde_json::to_vec(key).map_err(|e| {
331            error!("json error -> {:?}", e);
332            CacheError::SerdeJson
333        })?;
334
335        let mut stmt = self
336            .conn
337            .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
338            .map_err(|e| self.sqlite_error("prepare", &e))?;
339
340        stmt.execute(named_params! {
341            ":key": tag,
342            ":value": &data,
343        })
344        .map(|r| {
345            trace!("insert -> {:?}", r);
346        })
347        .map_err(|e| self.sqlite_error("execute", &e))
348    }
349
350    fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
351        self.conn
352            .execute(
353                "DELETE FROM hsm_data_t where key = :key",
354                named_params! {
355                    ":key": tag,
356                },
357            )
358            .map(|_| ())
359            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
360    }
361}
362
363impl DbTxn<'_> {
364    pub fn migrate(&mut self) -> Result<(), CacheError> {
365        // This definition can never change.
366        self.conn
367            .execute(
368                "CREATE TABLE IF NOT EXISTS db_version_t (
369                    id TEXT PRIMARY KEY,
370                    version INTEGER
371                )",
372                [],
373            )
374            .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
375
376        let db_version = self.get_db_version(DBV_MAIN);
377
378        if db_version < 1 {
379            // Setup two tables - one for accounts, one for groups.
380            // correctly index the columns.
381            // Optional pw hash field
382            self.conn
383                .execute(
384                    "CREATE TABLE IF NOT EXISTS account_t (
385                    uuid TEXT PRIMARY KEY,
386                    name TEXT NOT NULL UNIQUE,
387                    spn TEXT NOT NULL UNIQUE,
388                    gidnumber INTEGER NOT NULL UNIQUE,
389                    password BLOB,
390                    token BLOB NOT NULL,
391                    expiry NUMERIC NOT NULL
392                )
393                ",
394                    [],
395                )
396                .map_err(|e| self.sqlite_error("account_t create", &e))?;
397
398            self.conn
399                .execute(
400                    "CREATE TABLE IF NOT EXISTS group_t (
401                    uuid TEXT PRIMARY KEY,
402                    name TEXT NOT NULL UNIQUE,
403                    spn TEXT NOT NULL UNIQUE,
404                    gidnumber INTEGER NOT NULL UNIQUE,
405                    token BLOB NOT NULL,
406                    expiry NUMERIC NOT NULL
407                )
408                ",
409                    [],
410                )
411                .map_err(|e| self.sqlite_error("group_t create", &e))?;
412
413            // We defer group foreign keys here because we now manually cascade delete these when
414            // required. This is because insert or replace into will always delete then add
415            // which triggers this. So instead we defer and manually cascade.
416            //
417            // However, on accounts, we CAN delete cascade because accounts will always redefine
418            // their memberships on updates so this is safe to cascade on this direction.
419            self.conn
420                .execute(
421                    "CREATE TABLE IF NOT EXISTS memberof_t (
422                    g_uuid TEXT,
423                    a_uuid TEXT,
424                    FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
425                    FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
426                )
427                ",
428                    [],
429                )
430                .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
431
432            // Create the hsm_data store. These are all generally encrypted private
433            // keys, and the hsm structures will decrypt these as required.
434            self.conn
435                .execute(
436                    "CREATE TABLE IF NOT EXISTS hsm_int_t (
437                        key TEXT PRIMARY KEY,
438                        value BLOB NOT NULL
439                    )
440                    ",
441                    [],
442                )
443                .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
444
445            self.conn
446                .execute(
447                    "CREATE TABLE IF NOT EXISTS hsm_data_t (
448                        key TEXT PRIMARY KEY,
449                        value BLOB NOT NULL
450                    )
451                    ",
452                    [],
453                )
454                .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
455
456            // Since this is the 0th migration, we have to reset the HSM here.
457            self.clear_hsm()?;
458        }
459
460        if db_version < 2 {
461            self.conn
462                .execute(
463                    "CREATE INDEX IF NOT EXISTS account_t_uuid_idx ON account_t ( uuid )",
464                    [],
465                )
466                .map_err(|e| self.sqlite_error("account_t uuid index create", &e))?;
467
468            self.conn
469                .execute(
470                    "CREATE INDEX IF NOT EXISTS account_t_name_idx ON account_t ( name )",
471                    [],
472                )
473                .map_err(|e| self.sqlite_error("account_t name index create", &e))?;
474
475            self.conn
476                .execute(
477                    "CREATE INDEX IF NOT EXISTS account_t_spn_idx ON account_t ( spn )",
478                    [],
479                )
480                .map_err(|e| self.sqlite_error("account_t spn index create", &e))?;
481
482            self.conn
483                .execute(
484                    "CREATE INDEX IF NOT EXISTS account_t_gidnumber_idx ON account_t ( gidnumber )",
485                    [],
486                )
487                .map_err(|e| self.sqlite_error("account_t gidnumber index create", &e))?;
488
489            self.conn
490                .execute(
491                    "CREATE INDEX IF NOT EXISTS group_t_uuid_idx ON group_t ( uuid )",
492                    [],
493                )
494                .map_err(|e| self.sqlite_error("group_t uuid index create", &e))?;
495
496            self.conn
497                .execute(
498                    "CREATE INDEX IF NOT EXISTS group_t_name_idx ON group_t ( name )",
499                    [],
500                )
501                .map_err(|e| self.sqlite_error("group_t name index create", &e))?;
502
503            self.conn
504                .execute(
505                    "CREATE INDEX IF NOT EXISTS group_t_spn_idx ON group_t ( spn )",
506                    [],
507                )
508                .map_err(|e| self.sqlite_error("group_t spn index create", &e))?;
509
510            self.conn
511                .execute(
512                    "CREATE INDEX IF NOT EXISTS group_t_gidnumber_idx ON group_t ( gidnumber )",
513                    [],
514                )
515                .map_err(|e| self.sqlite_error("group_t gidnumber index create", &e))?;
516
517            self.conn
518                .execute(
519                    "CREATE INDEX IF NOT EXISTS memberof_t_g_uuid_idx ON memberof_t ( g_uuid )",
520                    [],
521                )
522                .map_err(|e| self.sqlite_error("memberof_t g_uuid index create", &e))?;
523
524            self.conn
525                .execute(
526                    "CREATE INDEX IF NOT EXISTS memberof_t_a_uuid_idx ON memberof_t ( a_uuid )",
527                    [],
528                )
529                .map_err(|e| self.sqlite_error("memberof_t a_uuid index create", &e))?;
530        }
531
532        self.set_db_version(DBV_MAIN, 2)?;
533
534        Ok(())
535    }
536
537    #[instrument(level = "debug", skip_all)]
538    pub fn commit(mut self) -> Result<(), CacheError> {
539        if self.committed {
540            error!("Invalid state, SQL transaction was already committed!");
541            return Err(CacheError::TransactionInvalidState);
542        }
543        self.committed = true;
544
545        self.conn
546            .execute("COMMIT TRANSACTION", [])
547            .map(|_| ())
548            .map_err(|e| self.sqlite_error("commit", &e))
549    }
550
551    pub fn invalidate(&mut self) -> Result<(), CacheError> {
552        self.conn
553            .execute("UPDATE group_t SET expiry = 0", [])
554            .map_err(|e| self.sqlite_error("update group_t", &e))?;
555
556        self.conn
557            .execute("UPDATE account_t SET expiry = 0", [])
558            .map_err(|e| self.sqlite_error("update account_t", &e))?;
559
560        Ok(())
561    }
562
563    pub fn clear(&mut self) -> Result<(), CacheError> {
564        self.conn
565            .execute("DELETE FROM memberof_t", [])
566            .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
567
568        self.conn
569            .execute("DELETE FROM group_t", [])
570            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
571
572        self.conn
573            .execute("DELETE FROM account_t", [])
574            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
575
576        Ok(())
577    }
578
579    pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
580        self.clear()?;
581
582        self.conn
583            .execute("DELETE FROM hsm_int_t", [])
584            .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
585
586        self.conn
587            .execute("DELETE FROM hsm_data_t", [])
588            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
589
590        Ok(())
591    }
592
593    pub fn get_hsm_root_storage_key(&mut self) -> Result<Option<LoadableStorageKey>, CacheError> {
594        let mut stmt = self
595            .conn
596            .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
597            .map_err(|e| self.sqlite_error("select prepare", &e))?;
598
599        let data: Option<Vec<u8>> = stmt
600            .query_row([], |row| row.get(0))
601            .optional()
602            .map_err(|e| self.sqlite_error("query_row", &e))?;
603
604        match data {
605            Some(d) => Ok(serde_json::from_slice(d.as_slice())
606                .map_err(|e| {
607                    error!("json error -> {:?}", e);
608                })
609                .ok()),
610            None => Ok(None),
611        }
612    }
613
614    pub fn insert_hsm_root_storage_key(
615        &mut self,
616        machine_key: &LoadableStorageKey,
617    ) -> Result<(), CacheError> {
618        let data = serde_json::to_vec(machine_key).map_err(|e| {
619            error!("insert_hsm_machine_key json error -> {:?}", e);
620            CacheError::SerdeJson
621        })?;
622
623        let mut stmt = self
624            .conn
625            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
626            .map_err(|e| self.sqlite_error("prepare", &e))?;
627
628        stmt.execute(named_params! {
629            ":key": "mk",
630            ":value": &data,
631        })
632        .map(|r| {
633            trace!("insert -> {:?}", r);
634        })
635        .map_err(|e| self.sqlite_error("execute", &e))
636    }
637
638    pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacS256Key>, CacheError> {
639        let mut stmt = self
640            .conn
641            .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
642            .map_err(|e| self.sqlite_error("select prepare", &e))?;
643
644        let data: Option<Vec<u8>> = stmt
645            .query_row([], |row| row.get(0))
646            .optional()
647            .map_err(|e| self.sqlite_error("query_row", &e))?;
648
649        match data {
650            Some(d) => Ok(serde_json::from_slice(d.as_slice())
651                .map_err(|e| {
652                    error!("json error -> {:?}", e);
653                })
654                .ok()),
655            None => Ok(None),
656        }
657    }
658
659    pub fn insert_hsm_hmac_key(
660        &mut self,
661        hmac_key: &LoadableHmacS256Key,
662    ) -> Result<(), CacheError> {
663        let data = serde_json::to_vec(hmac_key).map_err(|e| {
664            error!("insert_hsm_hmac_key json error -> {:?}", e);
665            CacheError::SerdeJson
666        })?;
667
668        let mut stmt = self
669            .conn
670            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
671            .map_err(|e| self.sqlite_error("prepare", &e))?;
672
673        stmt.execute(named_params! {
674            ":key": "hmac",
675            ":value": &data,
676        })
677        .map(|r| {
678            trace!("insert -> {:?}", r);
679        })
680        .map_err(|e| self.sqlite_error("execute", &e))
681    }
682
683    #[instrument(level = "debug", skip_all)]
684    pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
685        let data = match account_id {
686            Id::Name(n) => self.get_account_data_name(n.as_str()),
687            Id::Gid(g) => self.get_account_data_gid(*g),
688        }?;
689
690        // Assert only one result?
691        if data.len() >= 2 {
692            error!("invalid db state, multiple entries matched query?");
693            return Err(CacheError::TooManyResults);
694        }
695
696        if let Some((token, expiry)) = data.first() {
697            // token convert with json.
698            // If this errors, we specifically return Ok(None) because that triggers
699            // the cache to refetch the token.
700            match serde_json::from_slice(token.as_slice()) {
701                Ok(t) => {
702                    let e = u64::try_from(*expiry).map_err(|e| {
703                        error!("u64 convert error -> {:?}", e);
704                        CacheError::Parse
705                    })?;
706                    Ok(Some((t, e)))
707                }
708                Err(e) => {
709                    warn!("recoverable - json error -> {:?}", e);
710                    Ok(None)
711                }
712            }
713        } else {
714            Ok(None)
715        }
716    }
717
718    #[instrument(level = "debug", skip_all)]
719    pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
720        let mut stmt = self
721            .conn
722            .prepare("SELECT token FROM account_t")
723            .map_err(|e| self.sqlite_error("select prepare", &e))?;
724
725        let data_iter = stmt
726            .query_map([], |row| row.get(0))
727            .map_err(|e| self.sqlite_error("query_map", &e))?;
728        let data: Result<Vec<Vec<u8>>, _> = data_iter
729            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
730            .collect();
731
732        let data = data?;
733
734        Ok(data
735            .iter()
736            // We filter map here so that anything invalid is skipped.
737            .filter_map(|token| {
738                // token convert with json.
739                serde_json::from_slice(token.as_slice())
740                    .map_err(|e| {
741                        warn!("get_accounts json error -> {:?}", e);
742                    })
743                    .ok()
744            })
745            .collect())
746    }
747
748    #[instrument(level = "debug", skip_all)]
749    pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
750        let data = serde_json::to_vec(account).map_err(|e| {
751            error!("update_account json error -> {:?}", e);
752            CacheError::SerdeJson
753        })?;
754        let expire = i64::try_from(expire).map_err(|e| {
755            error!("update_account i64 conversion error -> {:?}", e);
756            CacheError::Parse
757        })?;
758
759        // This is needed because sqlites 'insert or replace into', will null the password field
760        // if present, and upsert MUST match the exact conflicting column, so that means we have
761        // to manually manage the update or insert :( :(
762        let account_uuid = account.uuid.as_hyphenated().to_string();
763
764        // Find anything conflicting and purge it.
765        self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
766            named_params!{
767                ":uuid": &account_uuid,
768                ":name": &account.name,
769                ":spn": &account.spn,
770                ":gidnumber": &account.gidnumber,
771            }
772            )
773            .map_err(|e| {
774                self.sqlite_error("delete account_t duplicate", &e)
775            })
776            .map(|_| ())?;
777
778        let updated = self.conn.execute(
779                "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
780            named_params!{
781                ":uuid": &account_uuid,
782                ":name": &account.name,
783                ":spn": &account.spn,
784                ":gidnumber": &account.gidnumber,
785                ":token": &data,
786                ":expiry": &expire,
787            }
788            )
789            .map_err(|e| {
790                self.sqlite_error("delete account_t duplicate", &e)
791            })?;
792
793        if updated == 0 {
794            let mut stmt = self.conn
795                .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
796                .map_err(|e| {
797                    self.sqlite_error("prepare", &e)
798                })?;
799
800            stmt.execute(named_params! {
801                ":uuid": &account_uuid,
802                ":name": &account.name,
803                ":spn": &account.spn,
804                ":gidnumber": &account.gidnumber,
805                ":token": &data,
806                ":expiry": &expire,
807            })
808            .map(|r| {
809                trace!("insert -> {:?}", r);
810            })
811            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
812        }
813
814        // Now, we have to update the group memberships.
815
816        // First remove everything that already exists:
817        let mut stmt = self
818            .conn
819            .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
820            .map_err(|e| self.sqlite_error("prepare", &e))?;
821
822        stmt.execute([&account_uuid])
823            .map(|r| {
824                trace!("delete memberships -> {:?}", r);
825            })
826            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
827
828        let mut stmt = self
829            .conn
830            .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
831            .map_err(|e| self.sqlite_error("prepare", &e))?;
832        // Now for each group, add the relation.
833        account.groups.iter().try_for_each(|g| {
834            stmt.execute(named_params! {
835                ":a_uuid": &account_uuid,
836                ":g_uuid": &g.uuid.as_hyphenated().to_string(),
837            })
838            .map(|r| {
839                trace!("insert membership -> {:?}", r);
840            })
841            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
842        })
843    }
844
845    #[instrument(level = "debug", skip_all)]
846    pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
847        let account_uuid = a_uuid.as_hyphenated().to_string();
848
849        self.conn
850            .execute(
851                "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
852                params![&account_uuid],
853            )
854            .map(|_| ())
855            .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
856
857        self.conn
858            .execute(
859                "DELETE FROM account_t WHERE uuid = :a_uuid",
860                params![&account_uuid],
861            )
862            .map(|_| ())
863            .map_err(|e| self.sqlite_error("account_t delete", &e))
864    }
865
866    #[instrument(level = "debug", skip_all)]
867    pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
868        let data = match grp_id {
869            Id::Name(n) => self.get_group_data_name(n.as_str()),
870            Id::Gid(g) => self.get_group_data_gid(*g),
871        }?;
872
873        // Assert only one result?
874        if data.len() >= 2 {
875            error!("invalid db state, multiple entries matched query?");
876            return Err(CacheError::TooManyResults);
877        }
878
879        if let Some((token, expiry)) = data.first() {
880            // token convert with json.
881            // If this errors, we specifically return Ok(None) because that triggers
882            // the cache to refetch the token.
883            match serde_json::from_slice(token.as_slice()) {
884                Ok(t) => {
885                    let e = u64::try_from(*expiry).map_err(|e| {
886                        error!("u64 convert error -> {:?}", e);
887                        CacheError::Parse
888                    })?;
889                    Ok(Some((t, e)))
890                }
891                Err(e) => {
892                    warn!("recoverable - json error -> {:?}", e);
893                    Ok(None)
894                }
895            }
896        } else {
897            Ok(None)
898        }
899    }
900
901    #[instrument(level = "debug", skip_all)]
902    pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
903        let mut stmt = self
904            .conn
905            .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
906            .map_err(|e| {
907                self.sqlite_error("select prepare", &e)
908            })?;
909
910        let data_iter = stmt
911            .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
912            .map_err(|e| self.sqlite_error("query_map", &e))?;
913        let data: Result<Vec<Vec<u8>>, _> = data_iter
914            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
915            .collect();
916
917        let data = data?;
918
919        data.iter()
920            .map(|token| {
921                // token convert with json.
922                // trace!("{:?}", token);
923                serde_json::from_slice(token.as_slice()).map_err(|e| {
924                    error!("json error -> {:?}", e);
925                    CacheError::SerdeJson
926                })
927            })
928            .collect()
929    }
930
931    #[instrument(level = "debug", skip_all)]
932    pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
933        let mut stmt = self
934            .conn
935            .prepare("SELECT token FROM group_t")
936            .map_err(|e| self.sqlite_error("select prepare", &e))?;
937
938        let data_iter = stmt
939            .query_map([], |row| row.get(0))
940            .map_err(|e| self.sqlite_error("query_map", &e))?;
941        let data: Result<Vec<Vec<u8>>, _> = data_iter
942            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
943            .collect();
944
945        let data = data?;
946
947        Ok(data
948            .iter()
949            .filter_map(|token| {
950                // token convert with json.
951                // trace!("{:?}", token);
952                serde_json::from_slice(token.as_slice())
953                    .map_err(|e| {
954                        error!("json error -> {:?}", e);
955                    })
956                    .ok()
957            })
958            .collect())
959    }
960
961    #[instrument(level = "debug", skip_all)]
962    pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
963        let data = serde_json::to_vec(grp).map_err(|e| {
964            error!("json error -> {:?}", e);
965            CacheError::SerdeJson
966        })?;
967        let expire = i64::try_from(expire).map_err(|e| {
968            error!("i64 convert error -> {:?}", e);
969            CacheError::Parse
970        })?;
971
972        let mut stmt = self.conn
973            .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
974            .map_err(|e| {
975                self.sqlite_error("prepare", &e)
976            })?;
977
978        // We have to to-str uuid as the sqlite impl makes it a blob which breaks our selects in get.
979        stmt.execute(named_params! {
980            ":uuid": &grp.uuid.as_hyphenated().to_string(),
981            ":name": &grp.name,
982            ":spn": &grp.spn,
983            ":gidnumber": &grp.gidnumber,
984            ":token": &data,
985            ":expiry": &expire,
986        })
987        .map(|r| {
988            trace!("insert -> {:?}", r);
989        })
990        .map_err(|e| self.sqlite_error("execute", &e))
991    }
992
993    #[instrument(level = "debug", skip_all)]
994    pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
995        let group_uuid = g_uuid.as_hyphenated().to_string();
996        self.conn
997            .execute(
998                "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
999                [&group_uuid],
1000            )
1001            .map(|_| ())
1002            .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
1003        self.conn
1004            .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
1005            .map(|_| ())
1006            .map_err(|e| self.sqlite_error("group_t delete", &e))
1007    }
1008}
1009
1010impl fmt::Debug for DbTxn<'_> {
1011    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1012        write!(f, "DbTxn {{}}")
1013    }
1014}
1015
1016impl Drop for DbTxn<'_> {
1017    // Abort
1018    fn drop(&mut self) {
1019        if !self.committed {
1020            // trace!("Aborting BE WR txn");
1021            #[allow(clippy::expect_used)]
1022            self.conn
1023                .execute("ROLLBACK TRANSACTION", [])
1024                .expect("Unable to rollback transaction! Can not proceed!!!");
1025        }
1026    }
1027}
1028
#[cfg(test)]
mod tests {
    use super::{Cache, Db};
    use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};

    // Every test opens `Db::new("")`; SQLite treats an empty filename as a private
    // temporary database, so tests do not share state.

    #[tokio::test]
    async fn test_cache_db_account_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());
        let id_spn = Id::Name("testuser@example.com".to_string());
        let id_spn2 = Id::Name("testuser2@example.com".to_string());
        // The uuid is passed as an `Id::Name`: name lookups must also match the uuid.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();

        // test we can get it.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // test updating an account that was renamed
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut1, 0).unwrap();

        // the old name/spn must no longer resolve; the new ones, uuid and gid must.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // Clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    #[tokio::test]
    async fn test_cache_db_group_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());
        let id_spn = Id::Name("testgroup@example.com".to_string());
        let id_spn2 = Id::Name("testgroup2@example.com".to_string());
        // The uuid is passed as an `Id::Name`: name lookups must also match the uuid.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding a group
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // rename the group via update; the old name/spn must no longer resolve
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing.
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    #[tokio::test]
    async fn test_cache_db_account_group_update() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        // gt1 shares its name, gidnumber and uuid with ut1 below.
        let gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
            extra_keys: Default::default(),
        };

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: vec![gt1.clone(), gt2],
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        // First, add the groups.
        ut1.groups.iter().for_each(|g| {
            dbtxn.update_group(g, 0).unwrap();
        });

        // Then add the account
        dbtxn.update_account(&ut1, 0).unwrap();

        // Now, get the memberships of the two groups. Each must hold exactly one
        // member, so duplicated membership rows would be caught here.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1.len(), 1);
        assert_eq!(m2.len(), 1);
        assert_eq!(m1[0].name, "testuser");
        assert_eq!(m2[0].name, "testuser");

        // Now alter testuser, remove gt2, update.
        ut1.groups = vec![gt1];
        dbtxn.update_account(&ut1, 0).unwrap();

        // Check that the memberships have updated correctly.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1.len(), 1);
        assert_eq!(m1[0].name, "testuser");
        assert!(m2.is_empty());

        assert!(dbtxn.commit().is_ok());
    }

    #[tokio::test]
    async fn test_cache_db_group_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding a group
        dbtxn.update_group(&gt1, 0).unwrap();
        let r0 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // Do the "rename" of gt1 which is what would allow gt2 to be valid. This
        // only changes the local token; the cache still holds gt1 under its old
        // name until gt1 is updated below.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        // Now, add gt2 which dups on gt1 name/spn.
        dbtxn.update_group(&gt2, 0).unwrap();
        let r2 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update gt1
        dbtxn.update_group(&gt1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_group(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }

    #[tokio::test]
    async fn test_cache_db_account_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let ut2 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();
        let r0 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // Do the "rename" of ut1 which is what would allow ut2 to be valid. This
        // only changes the local token; the cache still holds ut1 under its old
        // name until ut1 is updated below.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        // Now, add ut2 which dups on ut1 name/spn.
        dbtxn.update_account(&ut2, 0).unwrap();
        let r2 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update ut1
        dbtxn.update_account(&ut1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_account(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }
}