kanidm_unix_resolver/
db.rs

1use std::convert::TryFrom;
2use std::fmt;
3
4use crate::idprovider::interface::{GroupToken, Id, UserToken};
5use async_trait::async_trait;
6use libc::umask;
7use rusqlite::{Connection, OptionalExtension};
8use tokio::sync::{Mutex, MutexGuard};
9use uuid::Uuid;
10
11use serde::{de::DeserializeOwned, Serialize};
12
13use kanidm_hsm_crypto::{LoadableHmacKey, LoadableMachineKey};
14
/// Row id in `db_version_t` under which the main schema version is recorded.
const DBV_MAIN: &str = "main";
16
17#[async_trait]
18pub trait Cache {
19    type Txn<'db>
20    where
21        Self: 'db;
22
23    async fn write<'db>(&'db self) -> Self::Txn<'db>;
24}
25
26#[async_trait]
27pub trait KeyStore {
28    type Txn<'db>
29    where
30        Self: 'db;
31
32    async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
33}
34
/// Errors returned by cache transaction (`DbTxn`) operations.
#[derive(Debug)]
pub enum CacheError {
    /// Cryptographic failure (not raised within this file).
    Cryptography,
    /// A value could not be serialised to / deserialised from JSON.
    SerdeJson,
    /// A numeric conversion (e.g. between `u64` and the `i64` sqlite stores) failed.
    Parse,
    /// An sqlite statement failed; details are logged where it happened.
    Sqlite,
    /// A lookup that should match at most one row matched several.
    TooManyResults,
    /// The transaction was used in a state that does not allow the operation.
    TransactionInvalidState,
    /// TPM failure (not raised within this file).
    Tpm,
}
45
46pub struct Db {
47    conn: Mutex<Connection>,
48}
49
50pub struct DbTxn<'a> {
51    conn: MutexGuard<'a, Connection>,
52    committed: bool,
53}
54
55pub struct KeyStoreTxn<'a, 'b> {
56    db: &'b mut DbTxn<'a>,
57}
58
59impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
60    fn from(db: &'b mut DbTxn<'a>) -> Self {
61        Self { db }
62    }
63}
64
/// Errors returned while constructing a [`Db`].
#[derive(Debug)]
pub enum DbError {
    /// The sqlite connection could not be opened.
    Sqlite,
    /// TPM failure (not raised within this file).
    Tpm,
}
71
72impl Db {
73    pub fn new(path: &str) -> Result<Self, DbError> {
74        let before = unsafe { umask(0o0027) };
75        let conn = Connection::open(path).map_err(|e| {
76            error!(err = ?e, "rusqulite error");
77            DbError::Sqlite
78        })?;
79        let _ = unsafe { umask(before) };
80
81        Ok(Db {
82            conn: Mutex::new(conn),
83        })
84    }
85}
86
87#[async_trait]
88impl Cache for Db {
89    type Txn<'db> = DbTxn<'db>;
90
91    #[allow(clippy::expect_used)]
92    async fn write<'db>(&'db self) -> Self::Txn<'db> {
93        let conn = self.conn.lock().await;
94        DbTxn::new(conn)
95    }
96}
97
98impl fmt::Debug for Db {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        write!(f, "Db {{}}")
101    }
102}
103
104impl<'a> DbTxn<'a> {
105    fn new(conn: MutexGuard<'a, Connection>) -> Self {
106        // Start the transaction
107        // trace!("Starting db WR txn ...");
108        #[allow(clippy::expect_used)]
109        conn.execute("BEGIN TRANSACTION", [])
110            .expect("Unable to begin transaction!");
111        DbTxn {
112            committed: false,
113            conn,
114        }
115    }
116
117    /// This handles an error coming back from an sqlite event and dumps more information from it
118    fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
119        error!(
120            "sqlite {} error: {:?} db_path={:?}",
121            msg,
122            error,
123            &self.conn.path()
124        );
125        CacheError::Sqlite
126    }
127
128    /// This handles an error coming back from an sqlite transaction and dumps a load of information from it
129    fn sqlite_transaction_error(
130        &self,
131        error: &rusqlite::Error,
132        _stmt: &rusqlite::Statement,
133    ) -> CacheError {
134        error!(
135            "sqlite transaction error={:?} db_path={:?}",
136            error,
137            &self.conn.path(),
138        );
139        // TODO: one day figure out if there's an easy way to dump the transaction without the token...
140        CacheError::Sqlite
141    }
142
143    fn get_db_version(&self, key: &str) -> i64 {
144        self.conn
145            .query_row(
146                "SELECT version FROM db_version_t WHERE id = :id",
147                &[(":id", key)],
148                |row| row.get(0),
149            )
150            .unwrap_or({
151                // The value is missing, default to 0.
152                0
153            })
154    }
155
156    fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
157        self.conn
158            .execute(
159                "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
160                named_params! {
161                    ":id": &key,
162                    ":dbv": v,
163                },
164            )
165            .map(|_| ())
166            .map_err(|e| self.sqlite_error("set db_version_t", &e))
167    }
168
169    fn get_account_data_name(
170        &mut self,
171        account_id: &str,
172    ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
173        let mut stmt = self.conn
174            .prepare(
175        "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
176            )
177            .map_err(|e| {
178                self.sqlite_error("select prepare", &e)
179            })?;
180
181        // Makes tuple (token, expiry)
182        let data_iter = stmt
183            .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
184            .map_err(|e| self.sqlite_error("query_map failure", &e))?;
185        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
186            .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
187            .collect();
188        data
189    }
190
191    fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
192        let mut stmt = self
193            .conn
194            .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
195            .map_err(|e| self.sqlite_error("select prepare", &e))?;
196
197        // Makes tuple (token, expiry)
198        let data_iter = stmt
199            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
200            .map_err(|e| self.sqlite_error("query_map", &e))?;
201        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
202            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
203            .collect();
204        data
205    }
206
207    fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
208        let mut stmt = self.conn
209            .prepare(
210                "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
211            )
212            .map_err(|e| {
213                self.sqlite_error("select prepare", &e)
214            })?;
215
216        // Makes tuple (token, expiry)
217        let data_iter = stmt
218            .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
219            .map_err(|e| self.sqlite_error("query_map", &e))?;
220        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
221            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
222            .collect();
223        data
224    }
225
226    fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
227        let mut stmt = self
228            .conn
229            .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
230            .map_err(|e| self.sqlite_error("select prepare", &e))?;
231
232        // Makes tuple (token, expiry)
233        let data_iter = stmt
234            .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
235            .map_err(|e| self.sqlite_error("query_map", &e))?;
236        let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
237            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
238            .collect();
239        data
240    }
241}
242
243impl KeyStoreTxn<'_, '_> {
244    pub fn get_tagged_hsm_key<K: DeserializeOwned>(
245        &mut self,
246        tag: &str,
247    ) -> Result<Option<K>, CacheError> {
248        self.db.get_tagged_hsm_key(tag)
249    }
250
251    pub fn insert_tagged_hsm_key<K: Serialize>(
252        &mut self,
253        tag: &str,
254        key: &K,
255    ) -> Result<(), CacheError> {
256        self.db.insert_tagged_hsm_key(tag, key)
257    }
258
259    pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
260        self.db.delete_tagged_hsm_key(tag)
261    }
262}
263
264impl DbTxn<'_> {
265    fn get_tagged_hsm_key<K: DeserializeOwned>(
266        &mut self,
267        tag: &str,
268    ) -> Result<Option<K>, CacheError> {
269        let mut stmt = self
270            .conn
271            .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
272            .map_err(|e| self.sqlite_error("select prepare", &e))?;
273
274        let data: Option<Vec<u8>> = stmt
275            .query_row(
276                named_params! {
277                    ":key": tag
278                },
279                |row| row.get(0),
280            )
281            .optional()
282            .map_err(|e| self.sqlite_error("query_row", &e))?;
283
284        match data {
285            Some(d) => Ok(serde_json::from_slice(d.as_slice())
286                .map_err(|e| {
287                    error!("json error -> {:?}", e);
288                })
289                .ok()),
290            None => Ok(None),
291        }
292    }
293
294    fn insert_tagged_hsm_key<K: Serialize>(
295        &mut self,
296        tag: &str,
297        key: &K,
298    ) -> Result<(), CacheError> {
299        let data = serde_json::to_vec(key).map_err(|e| {
300            error!("json error -> {:?}", e);
301            CacheError::SerdeJson
302        })?;
303
304        let mut stmt = self
305            .conn
306            .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
307            .map_err(|e| self.sqlite_error("prepare", &e))?;
308
309        stmt.execute(named_params! {
310            ":key": tag,
311            ":value": &data,
312        })
313        .map(|r| {
314            trace!("insert -> {:?}", r);
315        })
316        .map_err(|e| self.sqlite_error("execute", &e))
317    }
318
319    fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
320        self.conn
321            .execute(
322                "DELETE FROM hsm_data_t where key = :key",
323                named_params! {
324                    ":key": tag,
325                },
326            )
327            .map(|_| ())
328            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
329    }
330}
331
332impl DbTxn<'_> {
333    pub fn migrate(&mut self) -> Result<(), CacheError> {
334        self.conn.set_prepared_statement_cache_capacity(16);
335        self.conn
336            .prepare("PRAGMA journal_mode=WAL;")
337            .and_then(|mut wal_stmt| wal_stmt.query([]).map(|_| ()))
338            .map_err(|e| self.sqlite_error("account_t create", &e))?;
339
340        // This definition can never change.
341        self.conn
342            .execute(
343                "CREATE TABLE IF NOT EXISTS db_version_t (
344                    id TEXT PRIMARY KEY,
345                    version INTEGER
346                )",
347                [],
348            )
349            .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
350
351        let db_version = self.get_db_version(DBV_MAIN);
352
353        if db_version < 1 {
354            // Setup two tables - one for accounts, one for groups.
355            // correctly index the columns.
356            // Optional pw hash field
357            self.conn
358                .execute(
359                    "CREATE TABLE IF NOT EXISTS account_t (
360                    uuid TEXT PRIMARY KEY,
361                    name TEXT NOT NULL UNIQUE,
362                    spn TEXT NOT NULL UNIQUE,
363                    gidnumber INTEGER NOT NULL UNIQUE,
364                    password BLOB,
365                    token BLOB NOT NULL,
366                    expiry NUMERIC NOT NULL
367                )
368                ",
369                    [],
370                )
371                .map_err(|e| self.sqlite_error("account_t create", &e))?;
372
373            self.conn
374                .execute(
375                    "CREATE TABLE IF NOT EXISTS group_t (
376                    uuid TEXT PRIMARY KEY,
377                    name TEXT NOT NULL UNIQUE,
378                    spn TEXT NOT NULL UNIQUE,
379                    gidnumber INTEGER NOT NULL UNIQUE,
380                    token BLOB NOT NULL,
381                    expiry NUMERIC NOT NULL
382                )
383                ",
384                    [],
385                )
386                .map_err(|e| self.sqlite_error("group_t create", &e))?;
387
388            // We defer group foreign keys here because we now manually cascade delete these when
389            // required. This is because insert or replace into will always delete then add
390            // which triggers this. So instead we defer and manually cascade.
391            //
392            // However, on accounts, we CAN delete cascade because accounts will always redefine
393            // their memberships on updates so this is safe to cascade on this direction.
394            self.conn
395                .execute(
396                    "CREATE TABLE IF NOT EXISTS memberof_t (
397                    g_uuid TEXT,
398                    a_uuid TEXT,
399                    FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
400                    FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
401                )
402                ",
403                    [],
404                )
405                .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
406
407            // Create the hsm_data store. These are all generally encrypted private
408            // keys, and the hsm structures will decrypt these as required.
409            self.conn
410                .execute(
411                    "CREATE TABLE IF NOT EXISTS hsm_int_t (
412                        key TEXT PRIMARY KEY,
413                        value BLOB NOT NULL
414                    )
415                    ",
416                    [],
417                )
418                .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
419
420            self.conn
421                .execute(
422                    "CREATE TABLE IF NOT EXISTS hsm_data_t (
423                        key TEXT PRIMARY KEY,
424                        value BLOB NOT NULL
425                    )
426                    ",
427                    [],
428                )
429                .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
430
431            // Since this is the 0th migration, we have to reset the HSM here.
432            self.clear_hsm()?;
433        }
434
435        self.set_db_version(DBV_MAIN, 1)?;
436
437        Ok(())
438    }
439
440    pub fn commit(mut self) -> Result<(), CacheError> {
441        if self.committed {
442            error!("Invalid state, SQL transaction was already committed!");
443            return Err(CacheError::TransactionInvalidState);
444        }
445        self.committed = true;
446
447        self.conn
448            .execute("COMMIT TRANSACTION", [])
449            .map(|_| ())
450            .map_err(|e| self.sqlite_error("commit", &e))
451    }
452
453    pub fn invalidate(&mut self) -> Result<(), CacheError> {
454        self.conn
455            .execute("UPDATE group_t SET expiry = 0", [])
456            .map_err(|e| self.sqlite_error("update group_t", &e))?;
457
458        self.conn
459            .execute("UPDATE account_t SET expiry = 0", [])
460            .map_err(|e| self.sqlite_error("update account_t", &e))?;
461
462        Ok(())
463    }
464
465    pub fn clear(&mut self) -> Result<(), CacheError> {
466        self.conn
467            .execute("DELETE FROM memberof_t", [])
468            .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
469
470        self.conn
471            .execute("DELETE FROM group_t", [])
472            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
473
474        self.conn
475            .execute("DELETE FROM account_t", [])
476            .map_err(|e| self.sqlite_error("delete group_t", &e))?;
477
478        Ok(())
479    }
480
481    pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
482        self.clear()?;
483
484        self.conn
485            .execute("DELETE FROM hsm_int_t", [])
486            .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
487
488        self.conn
489            .execute("DELETE FROM hsm_data_t", [])
490            .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
491
492        Ok(())
493    }
494
495    pub fn get_hsm_machine_key(&mut self) -> Result<Option<LoadableMachineKey>, CacheError> {
496        let mut stmt = self
497            .conn
498            .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
499            .map_err(|e| self.sqlite_error("select prepare", &e))?;
500
501        let data: Option<Vec<u8>> = stmt
502            .query_row([], |row| row.get(0))
503            .optional()
504            .map_err(|e| self.sqlite_error("query_row", &e))?;
505
506        match data {
507            Some(d) => Ok(serde_json::from_slice(d.as_slice())
508                .map_err(|e| {
509                    error!("json error -> {:?}", e);
510                })
511                .ok()),
512            None => Ok(None),
513        }
514    }
515
516    pub fn insert_hsm_machine_key(
517        &mut self,
518        machine_key: &LoadableMachineKey,
519    ) -> Result<(), CacheError> {
520        let data = serde_json::to_vec(machine_key).map_err(|e| {
521            error!("insert_hsm_machine_key json error -> {:?}", e);
522            CacheError::SerdeJson
523        })?;
524
525        let mut stmt = self
526            .conn
527            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
528            .map_err(|e| self.sqlite_error("prepare", &e))?;
529
530        stmt.execute(named_params! {
531            ":key": "mk",
532            ":value": &data,
533        })
534        .map(|r| {
535            trace!("insert -> {:?}", r);
536        })
537        .map_err(|e| self.sqlite_error("execute", &e))
538    }
539
540    pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacKey>, CacheError> {
541        let mut stmt = self
542            .conn
543            .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
544            .map_err(|e| self.sqlite_error("select prepare", &e))?;
545
546        let data: Option<Vec<u8>> = stmt
547            .query_row([], |row| row.get(0))
548            .optional()
549            .map_err(|e| self.sqlite_error("query_row", &e))?;
550
551        match data {
552            Some(d) => Ok(serde_json::from_slice(d.as_slice())
553                .map_err(|e| {
554                    error!("json error -> {:?}", e);
555                })
556                .ok()),
557            None => Ok(None),
558        }
559    }
560
561    pub fn insert_hsm_hmac_key(&mut self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError> {
562        let data = serde_json::to_vec(hmac_key).map_err(|e| {
563            error!("insert_hsm_hmac_key json error -> {:?}", e);
564            CacheError::SerdeJson
565        })?;
566
567        let mut stmt = self
568            .conn
569            .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
570            .map_err(|e| self.sqlite_error("prepare", &e))?;
571
572        stmt.execute(named_params! {
573            ":key": "hmac",
574            ":value": &data,
575        })
576        .map(|r| {
577            trace!("insert -> {:?}", r);
578        })
579        .map_err(|e| self.sqlite_error("execute", &e))
580    }
581
582    pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
583        let data = match account_id {
584            Id::Name(n) => self.get_account_data_name(n.as_str()),
585            Id::Gid(g) => self.get_account_data_gid(*g),
586        }?;
587
588        // Assert only one result?
589        if data.len() >= 2 {
590            error!("invalid db state, multiple entries matched query?");
591            return Err(CacheError::TooManyResults);
592        }
593
594        if let Some((token, expiry)) = data.first() {
595            // token convert with json.
596            // If this errors, we specifically return Ok(None) because that triggers
597            // the cache to refetch the token.
598            match serde_json::from_slice(token.as_slice()) {
599                Ok(t) => {
600                    let e = u64::try_from(*expiry).map_err(|e| {
601                        error!("u64 convert error -> {:?}", e);
602                        CacheError::Parse
603                    })?;
604                    Ok(Some((t, e)))
605                }
606                Err(e) => {
607                    warn!("recoverable - json error -> {:?}", e);
608                    Ok(None)
609                }
610            }
611        } else {
612            Ok(None)
613        }
614    }
615
616    pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
617        let mut stmt = self
618            .conn
619            .prepare("SELECT token FROM account_t")
620            .map_err(|e| self.sqlite_error("select prepare", &e))?;
621
622        let data_iter = stmt
623            .query_map([], |row| row.get(0))
624            .map_err(|e| self.sqlite_error("query_map", &e))?;
625        let data: Result<Vec<Vec<u8>>, _> = data_iter
626            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
627            .collect();
628
629        let data = data?;
630
631        Ok(data
632            .iter()
633            // We filter map here so that anything invalid is skipped.
634            .filter_map(|token| {
635                // token convert with json.
636                serde_json::from_slice(token.as_slice())
637                    .map_err(|e| {
638                        warn!("get_accounts json error -> {:?}", e);
639                    })
640                    .ok()
641            })
642            .collect())
643    }
644
645    pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
646        let data = serde_json::to_vec(account).map_err(|e| {
647            error!("update_account json error -> {:?}", e);
648            CacheError::SerdeJson
649        })?;
650        let expire = i64::try_from(expire).map_err(|e| {
651            error!("update_account i64 conversion error -> {:?}", e);
652            CacheError::Parse
653        })?;
654
655        // This is needed because sqlites 'insert or replace into', will null the password field
656        // if present, and upsert MUST match the exact conflicting column, so that means we have
657        // to manually manage the update or insert :( :(
658        let account_uuid = account.uuid.as_hyphenated().to_string();
659
660        // Find anything conflicting and purge it.
661        self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
662            named_params!{
663                ":uuid": &account_uuid,
664                ":name": &account.name,
665                ":spn": &account.spn,
666                ":gidnumber": &account.gidnumber,
667            }
668            )
669            .map_err(|e| {
670                self.sqlite_error("delete account_t duplicate", &e)
671            })
672            .map(|_| ())?;
673
674        let updated = self.conn.execute(
675                "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
676            named_params!{
677                ":uuid": &account_uuid,
678                ":name": &account.name,
679                ":spn": &account.spn,
680                ":gidnumber": &account.gidnumber,
681                ":token": &data,
682                ":expiry": &expire,
683            }
684            )
685            .map_err(|e| {
686                self.sqlite_error("delete account_t duplicate", &e)
687            })?;
688
689        if updated == 0 {
690            let mut stmt = self.conn
691                .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
692                .map_err(|e| {
693                    self.sqlite_error("prepare", &e)
694                })?;
695
696            stmt.execute(named_params! {
697                ":uuid": &account_uuid,
698                ":name": &account.name,
699                ":spn": &account.spn,
700                ":gidnumber": &account.gidnumber,
701                ":token": &data,
702                ":expiry": &expire,
703            })
704            .map(|r| {
705                trace!("insert -> {:?}", r);
706            })
707            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
708        }
709
710        // Now, we have to update the group memberships.
711
712        // First remove everything that already exists:
713        let mut stmt = self
714            .conn
715            .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
716            .map_err(|e| self.sqlite_error("prepare", &e))?;
717
718        stmt.execute([&account_uuid])
719            .map(|r| {
720                trace!("delete memberships -> {:?}", r);
721            })
722            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
723
724        let mut stmt = self
725            .conn
726            .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
727            .map_err(|e| self.sqlite_error("prepare", &e))?;
728        // Now for each group, add the relation.
729        account.groups.iter().try_for_each(|g| {
730            stmt.execute(named_params! {
731                ":a_uuid": &account_uuid,
732                ":g_uuid": &g.uuid.as_hyphenated().to_string(),
733            })
734            .map(|r| {
735                trace!("insert membership -> {:?}", r);
736            })
737            .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
738        })
739    }
740
741    pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
742        let account_uuid = a_uuid.as_hyphenated().to_string();
743
744        self.conn
745            .execute(
746                "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
747                params![&account_uuid],
748            )
749            .map(|_| ())
750            .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
751
752        self.conn
753            .execute(
754                "DELETE FROM account_t WHERE uuid = :a_uuid",
755                params![&account_uuid],
756            )
757            .map(|_| ())
758            .map_err(|e| self.sqlite_error("account_t delete", &e))
759    }
760
761    pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
762        let data = match grp_id {
763            Id::Name(n) => self.get_group_data_name(n.as_str()),
764            Id::Gid(g) => self.get_group_data_gid(*g),
765        }?;
766
767        // Assert only one result?
768        if data.len() >= 2 {
769            error!("invalid db state, multiple entries matched query?");
770            return Err(CacheError::TooManyResults);
771        }
772
773        if let Some((token, expiry)) = data.first() {
774            // token convert with json.
775            // If this errors, we specifically return Ok(None) because that triggers
776            // the cache to refetch the token.
777            match serde_json::from_slice(token.as_slice()) {
778                Ok(t) => {
779                    let e = u64::try_from(*expiry).map_err(|e| {
780                        error!("u64 convert error -> {:?}", e);
781                        CacheError::Parse
782                    })?;
783                    Ok(Some((t, e)))
784                }
785                Err(e) => {
786                    warn!("recoverable - json error -> {:?}", e);
787                    Ok(None)
788                }
789            }
790        } else {
791            Ok(None)
792        }
793    }
794
795    pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
796        let mut stmt = self
797            .conn
798            .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
799            .map_err(|e| {
800                self.sqlite_error("select prepare", &e)
801            })?;
802
803        let data_iter = stmt
804            .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
805            .map_err(|e| self.sqlite_error("query_map", &e))?;
806        let data: Result<Vec<Vec<u8>>, _> = data_iter
807            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
808            .collect();
809
810        let data = data?;
811
812        data.iter()
813            .map(|token| {
814                // token convert with json.
815                // trace!("{:?}", token);
816                serde_json::from_slice(token.as_slice()).map_err(|e| {
817                    error!("json error -> {:?}", e);
818                    CacheError::SerdeJson
819                })
820            })
821            .collect()
822    }
823
824    pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
825        let mut stmt = self
826            .conn
827            .prepare("SELECT token FROM group_t")
828            .map_err(|e| self.sqlite_error("select prepare", &e))?;
829
830        let data_iter = stmt
831            .query_map([], |row| row.get(0))
832            .map_err(|e| self.sqlite_error("query_map", &e))?;
833        let data: Result<Vec<Vec<u8>>, _> = data_iter
834            .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
835            .collect();
836
837        let data = data?;
838
839        Ok(data
840            .iter()
841            .filter_map(|token| {
842                // token convert with json.
843                // trace!("{:?}", token);
844                serde_json::from_slice(token.as_slice())
845                    .map_err(|e| {
846                        error!("json error -> {:?}", e);
847                    })
848                    .ok()
849            })
850            .collect())
851    }
852
853    pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
854        let data = serde_json::to_vec(grp).map_err(|e| {
855            error!("json error -> {:?}", e);
856            CacheError::SerdeJson
857        })?;
858        let expire = i64::try_from(expire).map_err(|e| {
859            error!("i64 convert error -> {:?}", e);
860            CacheError::Parse
861        })?;
862
863        let mut stmt = self.conn
864            .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
865            .map_err(|e| {
866                self.sqlite_error("prepare", &e)
867            })?;
868
869        // We have to to-str uuid as the sqlite impl makes it a blob which breaks our selects in get.
870        stmt.execute(named_params! {
871            ":uuid": &grp.uuid.as_hyphenated().to_string(),
872            ":name": &grp.name,
873            ":spn": &grp.spn,
874            ":gidnumber": &grp.gidnumber,
875            ":token": &data,
876            ":expiry": &expire,
877        })
878        .map(|r| {
879            trace!("insert -> {:?}", r);
880        })
881        .map_err(|e| self.sqlite_error("execute", &e))
882    }
883
884    pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
885        let group_uuid = g_uuid.as_hyphenated().to_string();
886        self.conn
887            .execute(
888                "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
889                [&group_uuid],
890            )
891            .map(|_| ())
892            .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
893        self.conn
894            .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
895            .map(|_| ())
896            .map_err(|e| self.sqlite_error("group_t delete", &e))
897    }
898}
899
900impl fmt::Debug for DbTxn<'_> {
901    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
902        write!(f, "DbTxn {{}}")
903    }
904}
905
906impl Drop for DbTxn<'_> {
907    // Abort
908    fn drop(&mut self) {
909        if !self.committed {
910            // trace!("Aborting BE WR txn");
911            #[allow(clippy::expect_used)]
912            self.conn
913                .execute("ROLLBACK TRANSACTION", [])
914                .expect("Unable to rollback transaction! Can not proceed!!!");
915        }
916    }
917}
918
#[cfg(test)]
mod tests {
    use super::{Cache, Db};
    use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};

    /// Account tokens can be added, found by name / spn / uuid / gid, renamed,
    /// and removed again by clearing the cache.
    #[tokio::test]
    async fn test_cache_db_account_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());
        let id_spn = Id::Name("testuser@example.com".to_string());
        let id_spn2 = Id::Name("testuser2@example.com".to_string());
        // A uuid lookup is expressed as an `Id::Name` holding the hyphenated uuid.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();

        // test we can get it by every identifier.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // Rename the account (same uuid) and store it again.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut1, 0).unwrap();

        // The old name/spn must no longer resolve; the new ones, the uuid and
        // the gid must.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // Clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// Group tokens can be added, found by name / spn / uuid / gid, renamed,
    /// and removed again by clearing the cache.
    #[tokio::test]
    async fn test_cache_db_group_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());
        let id_spn = Id::Name("testgroup@example.com".to_string());
        let id_spn2 = Id::Name("testgroup2@example.com".to_string());
        // A uuid lookup is expressed as an `Id::Name` holding the hyphenated uuid.
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        // test adding a group, then finding it by every identifier.
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // Rename the group (same uuid) via update; the old name/spn must no
        // longer resolve while the new ones, the uuid and the gid do.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&gt1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        // clear cache
        assert!(dbtxn.clear().is_ok());

        // should be nothing.
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// Updating an account rewrites its group memberships: groups the account
    /// no longer lists stop reporting it as a member.
    #[tokio::test]
    async fn test_cache_db_account_group_update() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
            extra_keys: Default::default(),
        };

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: vec![gt1.clone(), gt2],
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        // First, add the groups.
        ut1.groups.iter().for_each(|g| {
            dbtxn.update_group(g, 0).unwrap();
        });

        // Then add the account.
        dbtxn.update_account(&ut1, 0).unwrap();

        // Now, get the memberships of the two groups: the account is the sole
        // member of each.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(
            m1.iter().map(|u| u.name.as_str()).collect::<Vec<_>>(),
            vec!["testuser"]
        );
        assert_eq!(
            m2.iter().map(|u| u.name.as_str()).collect::<Vec<_>>(),
            vec!["testuser"]
        );

        // Now alter testuser, remove gt2, update.
        ut1.groups = vec![gt1];
        dbtxn.update_account(&ut1, 0).unwrap();

        // Check that the memberships have updated correctly: still exactly one
        // member of gt1, and none left in gt2.
        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(
            m1.iter().map(|u| u.name.as_str()).collect::<Vec<_>>(),
            vec!["testuser"]
        );
        assert!(m2.is_empty());

        assert!(dbtxn.commit().is_ok());
    }

    /// A new group may take over a name/spn that an older cached group still
    /// holds, and the older group is found under its new name once it is updated.
    #[tokio::test]
    async fn test_cache_db_group_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());

        // test finding no group
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding a group
        dbtxn.update_group(&gt1, 0).unwrap();
        let r0 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // Do the "rename" of gt1 which is what would allow gt2 to be valid.
        // This only changes the local token; the cache is not updated yet.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        // Now, add gt2 which dups on gt1 name/spn; lookups by the shared name
        // must resolve to gt2.
        dbtxn.update_group(&gt2, 0).unwrap();
        let r2 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update gt1
        dbtxn.update_group(&gt1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_group(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }

    /// A new account may take over a name/spn that an older cached account still
    /// holds, and the older account is found under its new name once it is updated.
    #[tokio::test]
    async fn test_cache_db_account_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let ut2 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());

        // test finding no account
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());

        // test adding an account
        dbtxn.update_account(&ut1, 0).unwrap();
        let r0 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // Do the "rename" of ut1 which is what would allow ut2 to be valid.
        // This only changes the local token; the cache is not updated yet.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        // Now, add ut2 which dups on ut1 name/spn; lookups by the shared name
        // must resolve to ut2.
        dbtxn.update_account(&ut2, 0).unwrap();
        let r2 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r3.is_none());

        // Now finally update ut1
        dbtxn.update_account(&ut1, 0).unwrap();

        // Both now coexist
        let r4 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_account(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }
}