kanidm_unix_resolver/
db.rs

1use crate::idprovider::interface::{GroupToken, Id, UserToken};
2use async_trait::async_trait;
3use kanidm_hsm_crypto::structures::{LoadableHmacS256Key, LoadableStorageKey};
4use libc::umask;
5use rusqlite::{Connection, OptionalExtension};
6use serde::{de::DeserializeOwned, Serialize};
7use std::convert::TryFrom;
8use std::fmt;
9use tokio::sync::{Mutex, MutexGuard};
10use uuid::Uuid;
11
/// Key under which the main schema version is recorded in `db_version_t`.
const DBV_MAIN: &str = "main";
/// SQLite's `cache_size` pragma counts *pages* when positive. SQLite's default page
/// size is 4096 bytes, so byte budgets are converted to pages using this value.
const SQLITE_PAGE_SIZE: i64 = 4096;
/// A 32 MiB page cache expressed in pages: 32 MiB / 4096 B = 8192 pages.
const CACHE_SIZE: i64 = (32 * 1024 * 1024) / SQLITE_PAGE_SIZE;
17
/// A cache backend that can open write transactions.
#[async_trait]
pub trait Cache {
    /// The write-transaction type; it may borrow from the backend for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Open a write transaction that borrows `self` for `'db`.
    async fn write<'db>(&'db self) -> Self::Txn<'db>;
}
26
/// A backend that can open write transactions for key storage.
#[async_trait]
pub trait KeyStore {
    /// The keystore transaction type; it may borrow from the backend for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Open a keystore write transaction that borrows `self` for `'db`.
    async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
}
35
/// Errors returned by cache transaction (`DbTxn`) operations.
#[derive(Debug)]
pub enum CacheError {
    /// NOTE(review): not produced in this part of the file; presumably a crypto failure — confirm at use sites.
    Cryptography,
    /// JSON serialisation of a value to be stored failed.
    SerdeJson,
    /// A numeric conversion failed (e.g. expiry `u64` <-> `i64`).
    Parse,
    /// An SQLite call failed; the cause is logged where it occurred.
    Sqlite,
    /// A lookup expected at most one row but matched several.
    TooManyResults,
    /// The transaction was used in an invalid state (e.g. committed twice).
    TransactionInvalidState,
    /// NOTE(review): not produced in this part of the file; presumably a TPM/HSM failure — confirm at use sites.
    Tpm,
}
46
/// Handle to the resolver's SQLite cache database.
///
/// The single connection sits behind an async mutex; a `DbTxn` holds the guard,
/// so at most one transaction can use the connection at a time.
pub struct Db {
    conn: Mutex<Connection>,
}
50
/// An open SQLite transaction on the locked `Db` connection.
///
/// Created by `DbTxn::new` (which issues `BEGIN TRANSACTION`) and finished with `commit`.
pub struct DbTxn<'a> {
    // Owning the guard keeps other tasks off the connection for as long as the
    // transaction is alive.
    conn: MutexGuard<'a, Connection>,
    // Set by `commit`. NOTE(review): presumably read by a `Drop` impl outside this
    // view to decide whether to roll back — confirm.
    committed: bool,
}
55
/// Borrowed wrapper around an open `DbTxn` that exposes the tagged HSM key operations.
pub struct KeyStoreTxn<'a, 'b> {
    db: &'b mut DbTxn<'a>,
}
59
60impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
61    fn from(db: &'b mut DbTxn<'a>) -> Self {
62        Self { db }
63    }
64}
65
/// Errors coming back from the `Db` struct
#[derive(Debug)]
pub enum DbError {
    /// Opening the SQLite connection or applying a connection pragma failed.
    Sqlite,
    /// NOTE(review): not produced in this part of the file; presumably a TPM/HSM failure — confirm at use sites.
    Tpm,
}
72
73impl Db {
74    pub fn new(path: &str) -> Result<Self, DbError> {
75        let before = unsafe { umask(0o0027) };
76        let conn = Connection::open(path).map_err(|e| {
77            error!(err = ?e, "rusqulite error");
78            DbError::Sqlite
79        })?;
80        let _ = unsafe { umask(before) };
81
82        // Setup WAL/COW mode.
83        conn.pragma_update(None, "journal_mode", "WAL")
84            .map_err(|error| {
85                error!(
86                    "sqlite journal_mode=WAL error: {:?} db_path={:?}",
87                    error, path
88                );
89                DbError::Sqlite
90            })?;
91
92        // synchronous=normal is safe for WAL
93        conn.pragma_update(None, "synchronous", "NORMAL")
94            .map_err(|error| {
95                error!(
96                    "sqlite synchronous=NORMAL error: {:?} db_path={:?}",
97                    error, path
98                );
99                DbError::Sqlite
100            })?;
101
102        conn.pragma_update(None, "cache_size", CACHE_SIZE)
103            .map_err(|error| {
104                error!(
105                    "sqlite cache_size={} error: {:?} db_path={:?}",
106                    CACHE_SIZE, error, path
107                );
108                DbError::Sqlite
109            })?;
110
111        conn.set_prepared_statement_cache_capacity(32);
112
113        Ok(Db {
114            conn: Mutex::new(conn),
115        })
116    }
117}
118
119#[async_trait]
120impl Cache for Db {
121    type Txn<'db> = DbTxn<'db>;
122
123    #[allow(clippy::expect_used)]
124    async fn write<'db>(&'db self) -> Self::Txn<'db> {
125        let conn = self.conn.lock().await;
126        DbTxn::new(conn)
127    }
128}
129
130impl fmt::Debug for Db {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(f, "Db {{}}")
133    }
134}
135
136impl<'a> DbTxn<'a> {
137    fn new(conn: MutexGuard<'a, Connection>) -> Self {
138        // Start the transaction
139        // trace!("Starting db WR txn ...");
140        #[allow(clippy::expect_used)]
141        conn.execute("BEGIN TRANSACTION", [])
142            .expect("Unable to begin transaction!");
143        DbTxn {
144            committed: false,
145            conn,
146        }
147    }
148
149    /// This handles an error coming back from an sqlite event and dumps more information from it
150    fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
151        error!(
152            "sqlite {} error: {:?} db_path={:?}",
153            msg,
154            error,
155            &self.conn.path()
156        );
157        CacheError::Sqlite
158    }
159
160    /// This handles an error coming back from an sqlite transaction and dumps a load of information from it
161    fn sqlite_transaction_error(
162        &self,
163        error: &rusqlite::Error,
164        _stmt: &rusqlite::Statement,
165    ) -> CacheError {
166        error!(
167            "sqlite transaction error={:?} db_path={:?}",
168            error,
169            &self.conn.path(),
170        );
171        // TODO: one day figure out if there's an easy way to dump the transaction without the token...
172        CacheError::Sqlite
173    }
174
175    fn get_db_version(&self, key: &str) -> i64 {
176        self.conn
177            .query_row(
178                "SELECT version FROM db_version_t WHERE id = :id",
179                &[(":id", key)],
180                |row| row.get(0),
181            )
182            .unwrap_or({
183                // The value is missing, default to 0.
184                0
185            })
186    }
187
188    fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
189        self.conn
190            .execute(
191                "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
192                named_params! {
193                    ":id": &key,
194                    ":dbv": v,
195                },
196            )
197            .map(|_| ())
198            .map_err(|e| self.sqlite_error("set db_version_t", &e))
199    }
200
201    fn get_account_data_name(
202        &mut self,
203        account_id: &str,
204    ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
205        let mut stmt = self.conn
206            .prepare(
207        "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
208            )
209            .map_err(|e| {
210                self.sqlite_error("select prepare", &e)
211            })?;
212
213        // Makes tuple (token, expiry)
214        let data_iter = stmt
215            .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
216            .map_err(|e| self.sqlite_error("query_map failure", &e))?;
217        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
218            .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
219            .collect();
220        data
221    }
222
223    fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
224        let mut stmt = self
225            .conn
226            .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
227            .map_err(|e| self.sqlite_error("select prepare", &e))?;
228
229        // Makes tuple (token, expiry)
230        let data_iter = stmt
231            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
232            .map_err(|e| self.sqlite_error("query_map", &e))?;
233        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
234            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
235            .collect();
236        data
237    }
238
239    fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
240        let mut stmt = self.conn
241            .prepare(
242                "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
243            )
244            .map_err(|e| {
245                self.sqlite_error("select prepare", &e)
246            })?;
247
248        // Makes tuple (token, expiry)
249        let data_iter = stmt
250            .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
251            .map_err(|e| self.sqlite_error("query_map", &e))?;
252        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
253            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
254            .collect();
255        data
256    }
257
258    fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
259        let mut stmt = self
260            .conn
261            .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
262            .map_err(|e| self.sqlite_error("select prepare", &e))?;
263
264        // Makes tuple (token, expiry)
265        let data_iter = stmt
266            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
267            .map_err(|e| self.sqlite_error("query_map", &e))?;
268        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
269            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
270            .collect();
271        data
272    }
273}
274
275impl KeyStoreTxn<'_, '_> {
276    pub fn get_tagged_hsm_key<K: DeserializeOwned>(
277        &mut self,
278        tag: &str,
279    ) -> Result<Option<K>, CacheError> {
280        self.db.get_tagged_hsm_key(tag)
281    }
282
283    pub fn insert_tagged_hsm_key<K: Serialize>(
284        &mut self,
285        tag: &str,
286        key: &K,
287    ) -> Result<(), CacheError> {
288        self.db.insert_tagged_hsm_key(tag, key)
289    }
290
291    pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
292        self.db.delete_tagged_hsm_key(tag)
293    }
294}
295
296impl DbTxn<'_> {
297    fn get_tagged_hsm_key<K: DeserializeOwned>(
298        &mut self,
299        tag: &str,
300    ) -> Result<Option<K>, CacheError> {
301        let mut stmt = self
302            .conn
303            .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
304            .map_err(|e| self.sqlite_error("select prepare", &e))?;
305
306        let data: Option<Vec<u8>> = stmt
307            .query_row(
308                named_params! {
309                    ":key": tag
310                },
311                |row| row.get(0),
312            )
313            .optional()
314            .map_err(|e| self.sqlite_error("query_row", &e))?;
315
316        match data {
317            Some(d) => Ok(serde_json::from_slice(d.as_slice())
318                .map_err(|e| {
319                    error!("json error -> {:?}", e);
320                })
321                .ok()),
322            None => Ok(None),
323        }
324    }
325
326    fn insert_tagged_hsm_key<K: Serialize>(
327        &mut self,
328        tag: &str,
329        key: &K,
330    ) -> Result<(), CacheError> {
331        let data = serde_json::to_vec(key).map_err(|e| {
332            error!("json error -> {:?}", e);
333            CacheError::SerdeJson
334        })?;
335
336        let mut stmt = self
337            .conn
338            .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
339            .map_err(|e| self.sqlite_error("prepare", &e))?;
340
341        stmt.execute(named_params! {
342            ":key": tag,
343            ":value": &data,
344        })
345        .map(|r| {
346            trace!("insert -> {:?}", r);
347        })
348        .map_err(|e| self.sqlite_error("execute", &e))
349    }
350
351    fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
352        self.conn
353            .execute(
354                "DELETE FROM hsm_data_t where key = :key",
355                named_params! {
356                    ":key": tag,
357                },
358            )
359            .map(|_| ())
360            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
361    }
362}
363
364impl DbTxn<'_> {
365    pub fn migrate(&mut self) -> Result<(), CacheError> {
366        // This definition can never change.
367        self.conn
368            .execute(
369                "CREATE TABLE IF NOT EXISTS db_version_t (
370                    id TEXT PRIMARY KEY,
371                    version INTEGER
372                )",
373                [],
374            )
375            .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
376
377        let db_version = self.get_db_version(DBV_MAIN);
378
379        if db_version < 1 {
380            // Setup two tables - one for accounts, one for groups.
381            // correctly index the columns.
382            // Optional pw hash field
383            self.conn
384                .execute(
385                    "CREATE TABLE IF NOT EXISTS account_t (
386                    uuid TEXT PRIMARY KEY,
387                    name TEXT NOT NULL UNIQUE,
388                    spn TEXT NOT NULL UNIQUE,
389                    gidnumber INTEGER NOT NULL UNIQUE,
390                    password BLOB,
391                    token BLOB NOT NULL,
392                    expiry NUMERIC NOT NULL
393                )
394                ",
395                    [],
396                )
397                .map_err(|e| self.sqlite_error("account_t create", &e))?;
398
399            self.conn
400                .execute(
401                    "CREATE TABLE IF NOT EXISTS group_t (
402                    uuid TEXT PRIMARY KEY,
403                    name TEXT NOT NULL UNIQUE,
404                    spn TEXT NOT NULL UNIQUE,
405                    gidnumber INTEGER NOT NULL UNIQUE,
406                    token BLOB NOT NULL,
407                    expiry NUMERIC NOT NULL
408                )
409                ",
410                    [],
411                )
412                .map_err(|e| self.sqlite_error("group_t create", &e))?;
413
414            // We defer group foreign keys here because we now manually cascade delete these when
415            // required. This is because insert or replace into will always delete then add
416            // which triggers this. So instead we defer and manually cascade.
417            //
418            // However, on accounts, we CAN delete cascade because accounts will always redefine
419            // their memberships on updates so this is safe to cascade on this direction.
420            self.conn
421                .execute(
422                    "CREATE TABLE IF NOT EXISTS memberof_t (
423                    g_uuid TEXT,
424                    a_uuid TEXT,
425                    FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
426                    FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
427                )
428                ",
429                    [],
430                )
431                .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
432
433            // Create the hsm_data store. These are all generally encrypted private
434            // keys, and the hsm structures will decrypt these as required.
435            self.conn
436                .execute(
437                    "CREATE TABLE IF NOT EXISTS hsm_int_t (
438                        key TEXT PRIMARY KEY,
439                        value BLOB NOT NULL
440                    )
441                    ",
442                    [],
443                )
444                .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
445
446            self.conn
447                .execute(
448                    "CREATE TABLE IF NOT EXISTS hsm_data_t (
449                        key TEXT PRIMARY KEY,
450                        value BLOB NOT NULL
451                    )
452                    ",
453                    [],
454                )
455                .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
456
457            // Since this is the 0th migration, we have to reset the HSM here.
458            self.clear_hsm()?;
459        }
460
461        if db_version < 2 {
462            self.conn
463                .execute(
464                    "CREATE INDEX IF NOT EXISTS account_t_uuid_idx ON account_t ( uuid )",
465                    [],
466                )
467                .map_err(|e| self.sqlite_error("account_t uuid index create", &e))?;
468
469            self.conn
470                .execute(
471                    "CREATE INDEX IF NOT EXISTS account_t_name_idx ON account_t ( name )",
472                    [],
473                )
474                .map_err(|e| self.sqlite_error("account_t name index create", &e))?;
475
476            self.conn
477                .execute(
478                    "CREATE INDEX IF NOT EXISTS account_t_spn_idx ON account_t ( spn )",
479                    [],
480                )
481                .map_err(|e| self.sqlite_error("account_t spn index create", &e))?;
482
483            self.conn
484                .execute(
485                    "CREATE INDEX IF NOT EXISTS account_t_gidnumber_idx ON account_t ( gidnumber )",
486                    [],
487                )
488                .map_err(|e| self.sqlite_error("account_t gidnumber index create", &e))?;
489
490            self.conn
491                .execute(
492                    "CREATE INDEX IF NOT EXISTS group_t_uuid_idx ON group_t ( uuid )",
493                    [],
494                )
495                .map_err(|e| self.sqlite_error("group_t uuid index create", &e))?;
496
497            self.conn
498                .execute(
499                    "CREATE INDEX IF NOT EXISTS group_t_name_idx ON group_t ( name )",
500                    [],
501                )
502                .map_err(|e| self.sqlite_error("group_t name index create", &e))?;
503
504            self.conn
505                .execute(
506                    "CREATE INDEX IF NOT EXISTS group_t_spn_idx ON group_t ( spn )",
507                    [],
508                )
509                .map_err(|e| self.sqlite_error("group_t spn index create", &e))?;
510
511            self.conn
512                .execute(
513                    "CREATE INDEX IF NOT EXISTS group_t_gidnumber_idx ON group_t ( gidnumber )",
514                    [],
515                )
516                .map_err(|e| self.sqlite_error("group_t gidnumber index create", &e))?;
517
518            self.conn
519                .execute(
520                    "CREATE INDEX IF NOT EXISTS memberof_t_g_uuid_idx ON memberof_t ( g_uuid )",
521                    [],
522                )
523                .map_err(|e| self.sqlite_error("memberof_t g_uuid index create", &e))?;
524
525            self.conn
526                .execute(
527                    "CREATE INDEX IF NOT EXISTS memberof_t_a_uuid_idx ON memberof_t ( a_uuid )",
528                    [],
529                )
530                .map_err(|e| self.sqlite_error("memberof_t a_uuid index create", &e))?;
531        }
532
533        self.set_db_version(DBV_MAIN, 2)?;
534
535        Ok(())
536    }
537
538    #[instrument(level = "debug", skip_all)]
539    pub fn commit(mut self) -> Result<(), CacheError> {
540        if self.committed {
541            error!("Invalid state, SQL transaction was already committed!");
542            return Err(CacheError::TransactionInvalidState);
543        }
544        self.committed = true;
545
546        self.conn
547            .execute("COMMIT TRANSACTION", [])
548            .map(|_| ())
549            .map_err(|e| self.sqlite_error("commit", &e))
550    }
551
552    pub fn invalidate(&mut self) -> Result<(), CacheError> {
553        self.conn
554            .execute("UPDATE group_t SET expiry = 0", [])
555            .map_err(|e| self.sqlite_error("update group_t", &e))?;
556
557        self.conn
558            .execute("UPDATE account_t SET expiry = 0", [])
559            .map_err(|e| self.sqlite_error("update account_t", &e))?;
560
561        Ok(())
562    }
563
564    pub fn clear(&mut self) -> Result<(), CacheError> {
565        self.conn
566            .execute("DELETE FROM memberof_t", [])
567            .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
568
569        self.conn
570            .execute("DELETE FROM group_t", [])
571            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
572
573        self.conn
574            .execute("DELETE FROM account_t", [])
575            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
576
577        Ok(())
578    }
579
580    pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
581        self.clear()?;
582
583        self.conn
584            .execute("DELETE FROM hsm_int_t", [])
585            .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
586
587        self.conn
588            .execute("DELETE FROM hsm_data_t", [])
589            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
590
591        Ok(())
592    }
593
594    pub fn get_hsm_root_storage_key(&mut self) -> Result<Option<LoadableStorageKey>, CacheError> {
595        let mut stmt = self
596            .conn
597            .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
598            .map_err(|e| self.sqlite_error("select prepare", &e))?;
599
600        let data: Option<Vec<u8>> = stmt
601            .query_row([], |row| row.get(0))
602            .optional()
603            .map_err(|e| self.sqlite_error("query_row", &e))?;
604
605        match data {
606            Some(d) => Ok(serde_json::from_slice(d.as_slice())
607                .map_err(|e| {
608                    error!("json error -> {:?}", e);
609                })
610                .ok()),
611            None => Ok(None),
612        }
613    }
614
615    pub fn insert_hsm_root_storage_key(
616        &mut self,
617        machine_key: &LoadableStorageKey,
618    ) -> Result<(), CacheError> {
619        let data = serde_json::to_vec(machine_key).map_err(|e| {
620            error!("insert_hsm_machine_key json error -> {:?}", e);
621            CacheError::SerdeJson
622        })?;
623
624        let mut stmt = self
625            .conn
626            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
627            .map_err(|e| self.sqlite_error("prepare", &e))?;
628
629        stmt.execute(named_params! {
630            ":key": "mk",
631            ":value": &data,
632        })
633        .map(|r| {
634            trace!("insert -> {:?}", r);
635        })
636        .map_err(|e| self.sqlite_error("execute", &e))
637    }
638
639    pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacS256Key>, CacheError> {
640        let mut stmt = self
641            .conn
642            .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
643            .map_err(|e| self.sqlite_error("select prepare", &e))?;
644
645        let data: Option<Vec<u8>> = stmt
646            .query_row([], |row| row.get(0))
647            .optional()
648            .map_err(|e| self.sqlite_error("query_row", &e))?;
649
650        match data {
651            Some(d) => Ok(serde_json::from_slice(d.as_slice())
652                .map_err(|e| {
653                    error!("json error -> {:?}", e);
654                })
655                .ok()),
656            None => Ok(None),
657        }
658    }
659
660    pub fn insert_hsm_hmac_key(
661        &mut self,
662        hmac_key: &LoadableHmacS256Key,
663    ) -> Result<(), CacheError> {
664        let data = serde_json::to_vec(hmac_key).map_err(|e| {
665            error!("insert_hsm_hmac_key json error -> {:?}", e);
666            CacheError::SerdeJson
667        })?;
668
669        let mut stmt = self
670            .conn
671            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
672            .map_err(|e| self.sqlite_error("prepare", &e))?;
673
674        stmt.execute(named_params! {
675            ":key": "hmac",
676            ":value": &data,
677        })
678        .map(|r| {
679            trace!("insert -> {:?}", r);
680        })
681        .map_err(|e| self.sqlite_error("execute", &e))
682    }
683
684    #[instrument(level = "debug", skip_all)]
685    pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
686        let data = match account_id {
687            Id::Name(n) => self.get_account_data_name(n.as_str()),
688            Id::Gid(g) => self.get_account_data_gid(*g),
689        }?;
690
691        // Assert only one result?
692        if data.len() >= 2 {
693            error!("invalid db state, multiple entries matched query?");
694            return Err(CacheError::TooManyResults);
695        }
696
697        if let Some((token, expiry)) = data.first() {
698            // token convert with json.
699            // If this errors, we specifically return Ok(None) because that triggers
700            // the cache to refetch the token.
701            match serde_json::from_slice(token.as_slice()) {
702                Ok(t) => {
703                    let e = u64::try_from(*expiry).map_err(|e| {
704                        error!("u64 convert error -> {:?}", e);
705                        CacheError::Parse
706                    })?;
707                    Ok(Some((t, e)))
708                }
709                Err(e) => {
710                    warn!("recoverable - json error -> {:?}", e);
711                    Ok(None)
712                }
713            }
714        } else {
715            Ok(None)
716        }
717    }
718
719    #[instrument(level = "debug", skip_all)]
720    pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
721        let mut stmt = self
722            .conn
723            .prepare("SELECT token FROM account_t")
724            .map_err(|e| self.sqlite_error("select prepare", &e))?;
725
726        let data_iter = stmt
727            .query_map([], |row| row.get(0))
728            .map_err(|e| self.sqlite_error("query_map", &e))?;
729        let data: Result<Vec<Vec<u8>>, _> = data_iter
730            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
731            .collect();
732
733        let data = data?;
734
735        Ok(data
736            .iter()
737            // We filter map here so that anything invalid is skipped.
738            .filter_map(|token| {
739                // token convert with json.
740                serde_json::from_slice(token.as_slice())
741                    .map_err(|e| {
742                        warn!("get_accounts json error -> {:?}", e);
743                    })
744                    .ok()
745            })
746            .collect())
747    }
748
749    #[instrument(level = "debug", skip_all)]
750    pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
751        let data = serde_json::to_vec(account).map_err(|e| {
752            error!("update_account json error -> {:?}", e);
753            CacheError::SerdeJson
754        })?;
755        let expire = i64::try_from(expire).map_err(|e| {
756            error!("update_account i64 conversion error -> {:?}", e);
757            CacheError::Parse
758        })?;
759
760        // This is needed because sqlites 'insert or replace into', will null the password field
761        // if present, and upsert MUST match the exact conflicting column, so that means we have
762        // to manually manage the update or insert :( :(
763        let account_uuid = account.uuid.as_hyphenated().to_string();
764
765        // Find anything conflicting and purge it.
766        self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
767            named_params!{
768                ":uuid": &account_uuid,
769                ":name": &account.name,
770                ":spn": &account.spn,
771                ":gidnumber": &account.gidnumber,
772            }
773            )
774            .map_err(|e| {
775                self.sqlite_error("delete account_t duplicate", &e)
776            })
777            .map(|_| ())?;
778
779        let updated = self.conn.execute(
780                "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
781            named_params!{
782                ":uuid": &account_uuid,
783                ":name": &account.name,
784                ":spn": &account.spn,
785                ":gidnumber": &account.gidnumber,
786                ":token": &data,
787                ":expiry": &expire,
788            }
789            )
790            .map_err(|e| {
791                self.sqlite_error("delete account_t duplicate", &e)
792            })?;
793
794        if updated == 0 {
795            let mut stmt = self.conn
796                .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
797                .map_err(|e| {
798                    self.sqlite_error("prepare", &e)
799                })?;
800
801            stmt.execute(named_params! {
802                ":uuid": &account_uuid,
803                ":name": &account.name,
804                ":spn": &account.spn,
805                ":gidnumber": &account.gidnumber,
806                ":token": &data,
807                ":expiry": &expire,
808            })
809            .map(|r| {
810                trace!("insert -> {:?}", r);
811            })
812            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
813        }
814
815        // Now, we have to update the group memberships.
816
817        // First remove everything that already exists:
818        let mut stmt = self
819            .conn
820            .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
821            .map_err(|e| self.sqlite_error("prepare", &e))?;
822
823        stmt.execute([&account_uuid])
824            .map(|r| {
825                trace!("delete memberships -> {:?}", r);
826            })
827            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
828
829        let mut stmt = self
830            .conn
831            .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
832            .map_err(|e| self.sqlite_error("prepare", &e))?;
833        // Now for each group, add the relation.
834        account.groups.iter().try_for_each(|g| {
835            stmt.execute(named_params! {
836                ":a_uuid": &account_uuid,
837                ":g_uuid": &g.uuid.as_hyphenated().to_string(),
838            })
839            .map(|r| {
840                trace!("insert membership -> {:?}", r);
841            })
842            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
843        })
844    }
845
846    #[instrument(level = "debug", skip_all)]
847    pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
848        let account_uuid = a_uuid.as_hyphenated().to_string();
849
850        self.conn
851            .execute(
852                "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
853                params![&account_uuid],
854            )
855            .map(|_| ())
856            .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
857
858        self.conn
859            .execute(
860                "DELETE FROM account_t WHERE uuid = :a_uuid",
861                params![&account_uuid],
862            )
863            .map(|_| ())
864            .map_err(|e| self.sqlite_error("account_t delete", &e))
865    }
866
867    #[instrument(level = "debug", skip_all)]
868    pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
869        let data = match grp_id {
870            Id::Name(n) => self.get_group_data_name(n.as_str()),
871            Id::Gid(g) => self.get_group_data_gid(*g),
872        }?;
873
874        // Assert only one result?
875        if data.len() >= 2 {
876            error!("invalid db state, multiple entries matched query?");
877            return Err(CacheError::TooManyResults);
878        }
879
880        if let Some((token, expiry)) = data.first() {
881            // token convert with json.
882            // If this errors, we specifically return Ok(None) because that triggers
883            // the cache to refetch the token.
884            match serde_json::from_slice(token.as_slice()) {
885                Ok(t) => {
886                    let e = u64::try_from(*expiry).map_err(|e| {
887                        error!("u64 convert error -> {:?}", e);
888                        CacheError::Parse
889                    })?;
890                    Ok(Some((t, e)))
891                }
892                Err(e) => {
893                    warn!("recoverable - json error -> {:?}", e);
894                    Ok(None)
895                }
896            }
897        } else {
898            Ok(None)
899        }
900    }
901
902    #[instrument(level = "debug", skip_all)]
903    pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
904        let mut stmt = self
905            .conn
906            .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
907            .map_err(|e| {
908                self.sqlite_error("select prepare", &e)
909            })?;
910
911        let data_iter = stmt
912            .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
913            .map_err(|e| self.sqlite_error("query_map", &e))?;
914        let data: Result<Vec<Vec<u8>>, _> = data_iter
915            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
916            .collect();
917
918        let data = data?;
919
920        data.iter()
921            .map(|token| {
922                // token convert with json.
923                // trace!("{:?}", token);
924                serde_json::from_slice(token.as_slice()).map_err(|e| {
925                    error!("json error -> {:?}", e);
926                    CacheError::SerdeJson
927                })
928            })
929            .collect()
930    }
931
932    #[instrument(level = "debug", skip_all)]
933    pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
934        let mut stmt = self
935            .conn
936            .prepare("SELECT token FROM group_t")
937            .map_err(|e| self.sqlite_error("select prepare", &e))?;
938
939        let data_iter = stmt
940            .query_map([], |row| row.get(0))
941            .map_err(|e| self.sqlite_error("query_map", &e))?;
942        let data: Result<Vec<Vec<u8>>, _> = data_iter
943            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
944            .collect();
945
946        let data = data?;
947
948        Ok(data
949            .iter()
950            .filter_map(|token| {
951                // token convert with json.
952                // trace!("{:?}", token);
953                serde_json::from_slice(token.as_slice())
954                    .map_err(|e| {
955                        error!("json error -> {:?}", e);
956                    })
957                    .ok()
958            })
959            .collect())
960    }
961
962    #[instrument(level = "debug", skip_all)]
963    pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
964        let data = serde_json::to_vec(grp).map_err(|e| {
965            error!("json error -> {:?}", e);
966            CacheError::SerdeJson
967        })?;
968        let expire = i64::try_from(expire).map_err(|e| {
969            error!("i64 convert error -> {:?}", e);
970            CacheError::Parse
971        })?;
972
973        let mut stmt = self.conn
974            .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
975            .map_err(|e| {
976                self.sqlite_error("prepare", &e)
977            })?;
978
979        // We have to to-str uuid as the sqlite impl makes it a blob which breaks our selects in get.
980        stmt.execute(named_params! {
981            ":uuid": &grp.uuid.as_hyphenated().to_string(),
982            ":name": &grp.name,
983            ":spn": &grp.spn,
984            ":gidnumber": &grp.gidnumber,
985            ":token": &data,
986            ":expiry": &expire,
987        })
988        .map(|r| {
989            trace!("insert -> {:?}", r);
990        })
991        .map_err(|e| self.sqlite_error("execute", &e))
992    }
993
994    #[instrument(level = "debug", skip_all)]
995    pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
996        let group_uuid = g_uuid.as_hyphenated().to_string();
997        self.conn
998            .execute(
999                "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
1000                [&group_uuid],
1001            )
1002            .map(|_| ())
1003            .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
1004        self.conn
1005            .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
1006            .map(|_| ())
1007            .map_err(|e| self.sqlite_error("group_t delete", &e))
1008    }
1009}
1010
1011impl fmt::Debug for DbTxn<'_> {
1012    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1013        write!(f, "DbTxn {{}}")
1014    }
1015}
1016
1017impl Drop for DbTxn<'_> {
1018    // Abort
1019    fn drop(&mut self) {
1020        if !self.committed {
1021            // trace!("Aborting BE WR txn");
1022            #[allow(clippy::expect_used)]
1023            self.conn
1024                .execute("ROLLBACK TRANSACTION", [])
1025                .expect("Unable to rollback transaction! Can not proceed!!!");
1026        }
1027    }
1028}
1029
#[cfg(test)]
mod tests {
    use super::{Cache, Db};
    use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};

    /// Account insert, lookup by name/spn/uuid/gid, rename, and cache clear.
    #[tokio::test]
    async fn test_cache_db_account_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());
        let id_spn = Id::Name("testuser@example.com".to_string());
        let id_spn2 = Id::Name("testuser2@example.com".to_string());
        // A uuid in text form is also looked up through `Id::Name`.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();

        // test we can get it.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // test updating an account that was renamed
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut1, 0).unwrap();

        // the old name/spn must be gone, the new ones and uuid/gid must resolve
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // Clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// Group insert, lookup by name/spn/uuid/gid, rename, and cache clear.
    #[tokio::test]
    async fn test_cache_db_group_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());
        let id_spn = Id::Name("testgroup@example.com".to_string());
        let id_spn2 = Id::Name("testgroup2@example.com".to_string());
        // A uuid in text form is also looked up through `Id::Name`.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding a group
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // rename the group via update: the old name/spn must stop resolving
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing.
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// Account group memberships are written on update, and removed when a group
    /// is dropped from the account.
    #[tokio::test]
    async fn test_cache_db_account_group_update() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
            extra_keys: Default::default(),
        };

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: vec![gt1.clone(), gt2],
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        // First, add the groups.
        ut1.groups.iter().for_each(|g| {
            dbtxn.update_group(g, 0).unwrap();
        });

        // Then add the account
        dbtxn.update_account(&ut1, 0).unwrap();

        // Now, get the memberships of the two groups.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        // Pin the lengths before indexing, so a missing membership fails with a
        // clear assertion instead of an out-of-bounds panic.
        assert_eq!(m1.len(), 1);
        assert_eq!(m2.len(), 1);
        assert_eq!(m1[0].name, "testuser");
        assert_eq!(m2[0].name, "testuser");

        // Now alter testuser, remove gt2, update.
        ut1.groups = vec![gt1];
        dbtxn.update_account(&ut1, 0).unwrap();

        // Check that the memberships have updated correctly.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1.len(), 1);
        assert_eq!(m1[0].name, "testuser");
        assert!(m2.is_empty());

        assert!(dbtxn.commit().is_ok());
    }

    /// A new group may take over the name/spn of a stale cached group. The stale
    /// group can later be updated under its new name, and both then coexist.
    #[tokio::test]
    async fn test_cache_db_group_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding a group
        dbtxn.update_group(&gt1, 0).unwrap();
        let r0 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // "Rename" gt1 locally, which is what makes gt2 valid. The cache still holds
        // gt1 under its old name until it is updated below.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        // Now, add gt2 which dups on gt1 name/spn.
        dbtxn.update_group(&gt2, 0).unwrap();
        let r2 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update gt1
        dbtxn.update_group(&gt1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_group(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }

    /// A new account may take over the name/spn of a stale cached account. The
    /// stale account can later be updated under its new name, and both then coexist.
    #[tokio::test]
    async fn test_cache_db_account_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let ut2 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();
        let r0 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // "Rename" ut1 locally, which is what makes ut2 valid. The cache still holds
        // ut1 under its old name until it is updated below.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        // Now, add ut2 which dups on ut1 name/spn.
        dbtxn.update_account(&ut2, 0).unwrap();
        let r2 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update ut1
        dbtxn.update_account(&ut1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_account(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }
}