1use crate::idprovider::interface::{GroupToken, Id, UserToken};
2use async_trait::async_trait;
3use kanidm_hsm_crypto::structures::{LoadableHmacS256Key, LoadableStorageKey};
4use libc::umask;
5use rusqlite::{Connection, OptionalExtension};
6use serde::{de::DeserializeOwned, Serialize};
7use std::convert::TryFrom;
8use std::fmt;
9use tokio::sync::{Mutex, MutexGuard};
10use uuid::Uuid;
11
/// Row id in `db_version_t` under which the main schema version is stored.
const DBV_MAIN: &str = "main";
/// Value handed to SQLite's `cache_size` pragma: 32 MiB expressed as a count of
/// 4096-byte pages. NOTE(review): a positive `cache_size` is a page count, so the
/// 32 MiB figure assumes the database page size is 4096 bytes — confirm.
const CACHE_SIZE: usize = 32 * ((1024 * 1024) / 4096);
16
/// A cache backend that can open write transactions.
///
/// `Txn<'db>` is the transaction handle produced by [`Cache::write`]. The
/// `Self: 'db` bound allows that handle to borrow from the backend for `'db`.
#[async_trait]
pub trait Cache {
    /// Transaction handle, possibly borrowing from the backend for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Open a write transaction on the backend.
    async fn write<'db>(&'db self) -> Self::Txn<'db>;
}
25
/// A backend that can open a transaction for key-store operations.
///
/// `Txn<'db>` is the handle produced by [`KeyStore::write_keystore`]. The
/// `Self: 'db` bound allows that handle to borrow from the backend for `'db`.
#[async_trait]
pub trait KeyStore {
    /// Key-store transaction handle, possibly borrowing from the backend for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Open a key-store transaction on the backend.
    async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
}
34
/// Errors returned by cache transaction operations.
///
/// Variants carry no payload. Where this file constructs them, the underlying
/// error is logged at the point of failure.
#[derive(Debug)]
pub enum CacheError {
    /// NOTE(review): not constructed in this file.
    Cryptography,
    /// JSON serialisation of a token or key failed.
    SerdeJson,
    /// Converting an expiry between `u64` and SQLite's `i64` failed.
    Parse,
    /// A SQLite prepare/execute/query failed.
    Sqlite,
    /// A lookup that should match at most one cached entry matched several.
    TooManyResults,
    /// `commit` was called on a transaction already marked committed.
    TransactionInvalidState,
    /// NOTE(review): not constructed in this file.
    Tpm,
}
45
/// SQLite-backed cache database.
///
/// The single [`Connection`] sits behind an async [`Mutex`]. [`Cache::write`]
/// moves the lock guard into the returned [`DbTxn`], so the lock is held for
/// the transaction's whole lifetime and at most one transaction is open at a time.
pub struct Db {
    conn: Mutex<Connection>,
}
49
/// An open SQLite transaction on the [`Db`] connection.
///
/// Holds the connection's mutex guard, so the lock is released only when the
/// transaction value is dropped. Unless `committed` is set, dropping it issues
/// `ROLLBACK TRANSACTION`.
pub struct DbTxn<'a> {
    /// Guard over the shared connection; keeps other writers out.
    conn: MutexGuard<'a, Connection>,
    /// Set by `commit`; when true, `Drop` skips the rollback.
    committed: bool,
}
54
/// A narrowed view of a [`DbTxn`] that exposes only the tagged HSM key
/// operations (get / insert / delete), which are stored in `hsm_data_t`.
pub struct KeyStoreTxn<'a, 'b> {
    /// The borrowed transaction all operations are forwarded to.
    db: &'b mut DbTxn<'a>,
}
58
59impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
60 fn from(db: &'b mut DbTxn<'a>) -> Self {
61 Self { db }
62 }
63}
64
/// Errors from opening and configuring a [`Db`].
#[derive(Debug)]
pub enum DbError {
    /// Opening the database file or applying a pragma failed (details are logged).
    Sqlite,
    /// NOTE(review): not constructed in this file.
    Tpm,
}
71
72impl Db {
73 pub fn new(path: &str) -> Result<Self, DbError> {
74 let before = unsafe { umask(0o0027) };
75 let conn = Connection::open(path).map_err(|e| {
76 error!(err = ?e, "rusqulite error");
77 DbError::Sqlite
78 })?;
79 let _ = unsafe { umask(before) };
80
81 conn.pragma_update(None, "journal_mode", "WAL")
83 .map_err(|error| {
84 error!(
85 "sqlite journal_mode=WAL error: {:?} db_path={:?}",
86 error, path
87 );
88 DbError::Sqlite
89 })?;
90
91 conn.pragma_update(None, "cache_size", CACHE_SIZE)
92 .map_err(|error| {
93 error!(
94 "sqlite cache_size={} error: {:?} db_path={:?}",
95 CACHE_SIZE, error, path
96 );
97 DbError::Sqlite
98 })?;
99
100 conn.set_prepared_statement_cache_capacity(32);
101
102 Ok(Db {
103 conn: Mutex::new(conn),
104 })
105 }
106}
107
108#[async_trait]
109impl Cache for Db {
110 type Txn<'db> = DbTxn<'db>;
111
112 #[allow(clippy::expect_used)]
113 async fn write<'db>(&'db self) -> Self::Txn<'db> {
114 let conn = self.conn.lock().await;
115 DbTxn::new(conn)
116 }
117}
118
119impl fmt::Debug for Db {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 write!(f, "Db {{}}")
122 }
123}
124
125impl<'a> DbTxn<'a> {
126 fn new(conn: MutexGuard<'a, Connection>) -> Self {
127 #[allow(clippy::expect_used)]
130 conn.execute("BEGIN TRANSACTION", [])
131 .expect("Unable to begin transaction!");
132 DbTxn {
133 committed: false,
134 conn,
135 }
136 }
137
138 fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
140 error!(
141 "sqlite {} error: {:?} db_path={:?}",
142 msg,
143 error,
144 &self.conn.path()
145 );
146 CacheError::Sqlite
147 }
148
149 fn sqlite_transaction_error(
151 &self,
152 error: &rusqlite::Error,
153 _stmt: &rusqlite::Statement,
154 ) -> CacheError {
155 error!(
156 "sqlite transaction error={:?} db_path={:?}",
157 error,
158 &self.conn.path(),
159 );
160 CacheError::Sqlite
162 }
163
164 fn get_db_version(&self, key: &str) -> i64 {
165 self.conn
166 .query_row(
167 "SELECT version FROM db_version_t WHERE id = :id",
168 &[(":id", key)],
169 |row| row.get(0),
170 )
171 .unwrap_or({
172 0
174 })
175 }
176
177 fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
178 self.conn
179 .execute(
180 "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
181 named_params! {
182 ":id": &key,
183 ":dbv": v,
184 },
185 )
186 .map(|_| ())
187 .map_err(|e| self.sqlite_error("set db_version_t", &e))
188 }
189
190 fn get_account_data_name(
191 &mut self,
192 account_id: &str,
193 ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
194 let mut stmt = self.conn
195 .prepare(
196 "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
197 )
198 .map_err(|e| {
199 self.sqlite_error("select prepare", &e)
200 })?;
201
202 let data_iter = stmt
204 .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
205 .map_err(|e| self.sqlite_error("query_map failure", &e))?;
206 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
207 .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
208 .collect();
209 data
210 }
211
212 fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
213 let mut stmt = self
214 .conn
215 .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
216 .map_err(|e| self.sqlite_error("select prepare", &e))?;
217
218 let data_iter = stmt
220 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
221 .map_err(|e| self.sqlite_error("query_map", &e))?;
222 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
223 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
224 .collect();
225 data
226 }
227
228 fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
229 let mut stmt = self.conn
230 .prepare(
231 "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
232 )
233 .map_err(|e| {
234 self.sqlite_error("select prepare", &e)
235 })?;
236
237 let data_iter = stmt
239 .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
240 .map_err(|e| self.sqlite_error("query_map", &e))?;
241 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
242 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
243 .collect();
244 data
245 }
246
247 fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
248 let mut stmt = self
249 .conn
250 .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
251 .map_err(|e| self.sqlite_error("select prepare", &e))?;
252
253 let data_iter = stmt
255 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
256 .map_err(|e| self.sqlite_error("query_map", &e))?;
257 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
258 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
259 .collect();
260 data
261 }
262}
263
264impl KeyStoreTxn<'_, '_> {
265 pub fn get_tagged_hsm_key<K: DeserializeOwned>(
266 &mut self,
267 tag: &str,
268 ) -> Result<Option<K>, CacheError> {
269 self.db.get_tagged_hsm_key(tag)
270 }
271
272 pub fn insert_tagged_hsm_key<K: Serialize>(
273 &mut self,
274 tag: &str,
275 key: &K,
276 ) -> Result<(), CacheError> {
277 self.db.insert_tagged_hsm_key(tag, key)
278 }
279
280 pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
281 self.db.delete_tagged_hsm_key(tag)
282 }
283}
284
285impl DbTxn<'_> {
286 fn get_tagged_hsm_key<K: DeserializeOwned>(
287 &mut self,
288 tag: &str,
289 ) -> Result<Option<K>, CacheError> {
290 let mut stmt = self
291 .conn
292 .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
293 .map_err(|e| self.sqlite_error("select prepare", &e))?;
294
295 let data: Option<Vec<u8>> = stmt
296 .query_row(
297 named_params! {
298 ":key": tag
299 },
300 |row| row.get(0),
301 )
302 .optional()
303 .map_err(|e| self.sqlite_error("query_row", &e))?;
304
305 match data {
306 Some(d) => Ok(serde_json::from_slice(d.as_slice())
307 .map_err(|e| {
308 error!("json error -> {:?}", e);
309 })
310 .ok()),
311 None => Ok(None),
312 }
313 }
314
315 fn insert_tagged_hsm_key<K: Serialize>(
316 &mut self,
317 tag: &str,
318 key: &K,
319 ) -> Result<(), CacheError> {
320 let data = serde_json::to_vec(key).map_err(|e| {
321 error!("json error -> {:?}", e);
322 CacheError::SerdeJson
323 })?;
324
325 let mut stmt = self
326 .conn
327 .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
328 .map_err(|e| self.sqlite_error("prepare", &e))?;
329
330 stmt.execute(named_params! {
331 ":key": tag,
332 ":value": &data,
333 })
334 .map(|r| {
335 trace!("insert -> {:?}", r);
336 })
337 .map_err(|e| self.sqlite_error("execute", &e))
338 }
339
340 fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
341 self.conn
342 .execute(
343 "DELETE FROM hsm_data_t where key = :key",
344 named_params! {
345 ":key": tag,
346 },
347 )
348 .map(|_| ())
349 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
350 }
351}
352
353impl DbTxn<'_> {
354 pub fn migrate(&mut self) -> Result<(), CacheError> {
355 self.conn
357 .execute(
358 "CREATE TABLE IF NOT EXISTS db_version_t (
359 id TEXT PRIMARY KEY,
360 version INTEGER
361 )",
362 [],
363 )
364 .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
365
366 let db_version = self.get_db_version(DBV_MAIN);
367
368 if db_version < 1 {
369 self.conn
373 .execute(
374 "CREATE TABLE IF NOT EXISTS account_t (
375 uuid TEXT PRIMARY KEY,
376 name TEXT NOT NULL UNIQUE,
377 spn TEXT NOT NULL UNIQUE,
378 gidnumber INTEGER NOT NULL UNIQUE,
379 password BLOB,
380 token BLOB NOT NULL,
381 expiry NUMERIC NOT NULL
382 )
383 ",
384 [],
385 )
386 .map_err(|e| self.sqlite_error("account_t create", &e))?;
387
388 self.conn
389 .execute(
390 "CREATE TABLE IF NOT EXISTS group_t (
391 uuid TEXT PRIMARY KEY,
392 name TEXT NOT NULL UNIQUE,
393 spn TEXT NOT NULL UNIQUE,
394 gidnumber INTEGER NOT NULL UNIQUE,
395 token BLOB NOT NULL,
396 expiry NUMERIC NOT NULL
397 )
398 ",
399 [],
400 )
401 .map_err(|e| self.sqlite_error("group_t create", &e))?;
402
403 self.conn
410 .execute(
411 "CREATE TABLE IF NOT EXISTS memberof_t (
412 g_uuid TEXT,
413 a_uuid TEXT,
414 FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
415 FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
416 )
417 ",
418 [],
419 )
420 .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
421
422 self.conn
425 .execute(
426 "CREATE TABLE IF NOT EXISTS hsm_int_t (
427 key TEXT PRIMARY KEY,
428 value BLOB NOT NULL
429 )
430 ",
431 [],
432 )
433 .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
434
435 self.conn
436 .execute(
437 "CREATE TABLE IF NOT EXISTS hsm_data_t (
438 key TEXT PRIMARY KEY,
439 value BLOB NOT NULL
440 )
441 ",
442 [],
443 )
444 .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
445
446 self.clear_hsm()?;
448 }
449
450 self.set_db_version(DBV_MAIN, 1)?;
451
452 Ok(())
453 }
454
455 pub fn commit(mut self) -> Result<(), CacheError> {
456 if self.committed {
457 error!("Invalid state, SQL transaction was already committed!");
458 return Err(CacheError::TransactionInvalidState);
459 }
460 self.committed = true;
461
462 self.conn
463 .execute("COMMIT TRANSACTION", [])
464 .map(|_| ())
465 .map_err(|e| self.sqlite_error("commit", &e))
466 }
467
468 pub fn invalidate(&mut self) -> Result<(), CacheError> {
469 self.conn
470 .execute("UPDATE group_t SET expiry = 0", [])
471 .map_err(|e| self.sqlite_error("update group_t", &e))?;
472
473 self.conn
474 .execute("UPDATE account_t SET expiry = 0", [])
475 .map_err(|e| self.sqlite_error("update account_t", &e))?;
476
477 Ok(())
478 }
479
480 pub fn clear(&mut self) -> Result<(), CacheError> {
481 self.conn
482 .execute("DELETE FROM memberof_t", [])
483 .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
484
485 self.conn
486 .execute("DELETE FROM group_t", [])
487 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
488
489 self.conn
490 .execute("DELETE FROM account_t", [])
491 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
492
493 Ok(())
494 }
495
496 pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
497 self.clear()?;
498
499 self.conn
500 .execute("DELETE FROM hsm_int_t", [])
501 .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
502
503 self.conn
504 .execute("DELETE FROM hsm_data_t", [])
505 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
506
507 Ok(())
508 }
509
510 pub fn get_hsm_root_storage_key(&mut self) -> Result<Option<LoadableStorageKey>, CacheError> {
511 let mut stmt = self
512 .conn
513 .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
514 .map_err(|e| self.sqlite_error("select prepare", &e))?;
515
516 let data: Option<Vec<u8>> = stmt
517 .query_row([], |row| row.get(0))
518 .optional()
519 .map_err(|e| self.sqlite_error("query_row", &e))?;
520
521 match data {
522 Some(d) => Ok(serde_json::from_slice(d.as_slice())
523 .map_err(|e| {
524 error!("json error -> {:?}", e);
525 })
526 .ok()),
527 None => Ok(None),
528 }
529 }
530
531 pub fn insert_hsm_root_storage_key(
532 &mut self,
533 machine_key: &LoadableStorageKey,
534 ) -> Result<(), CacheError> {
535 let data = serde_json::to_vec(machine_key).map_err(|e| {
536 error!("insert_hsm_machine_key json error -> {:?}", e);
537 CacheError::SerdeJson
538 })?;
539
540 let mut stmt = self
541 .conn
542 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
543 .map_err(|e| self.sqlite_error("prepare", &e))?;
544
545 stmt.execute(named_params! {
546 ":key": "mk",
547 ":value": &data,
548 })
549 .map(|r| {
550 trace!("insert -> {:?}", r);
551 })
552 .map_err(|e| self.sqlite_error("execute", &e))
553 }
554
555 pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacS256Key>, CacheError> {
556 let mut stmt = self
557 .conn
558 .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
559 .map_err(|e| self.sqlite_error("select prepare", &e))?;
560
561 let data: Option<Vec<u8>> = stmt
562 .query_row([], |row| row.get(0))
563 .optional()
564 .map_err(|e| self.sqlite_error("query_row", &e))?;
565
566 match data {
567 Some(d) => Ok(serde_json::from_slice(d.as_slice())
568 .map_err(|e| {
569 error!("json error -> {:?}", e);
570 })
571 .ok()),
572 None => Ok(None),
573 }
574 }
575
576 pub fn insert_hsm_hmac_key(
577 &mut self,
578 hmac_key: &LoadableHmacS256Key,
579 ) -> Result<(), CacheError> {
580 let data = serde_json::to_vec(hmac_key).map_err(|e| {
581 error!("insert_hsm_hmac_key json error -> {:?}", e);
582 CacheError::SerdeJson
583 })?;
584
585 let mut stmt = self
586 .conn
587 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
588 .map_err(|e| self.sqlite_error("prepare", &e))?;
589
590 stmt.execute(named_params! {
591 ":key": "hmac",
592 ":value": &data,
593 })
594 .map(|r| {
595 trace!("insert -> {:?}", r);
596 })
597 .map_err(|e| self.sqlite_error("execute", &e))
598 }
599
600 pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
601 let data = match account_id {
602 Id::Name(n) => self.get_account_data_name(n.as_str()),
603 Id::Gid(g) => self.get_account_data_gid(*g),
604 }?;
605
606 if data.len() >= 2 {
608 error!("invalid db state, multiple entries matched query?");
609 return Err(CacheError::TooManyResults);
610 }
611
612 if let Some((token, expiry)) = data.first() {
613 match serde_json::from_slice(token.as_slice()) {
617 Ok(t) => {
618 let e = u64::try_from(*expiry).map_err(|e| {
619 error!("u64 convert error -> {:?}", e);
620 CacheError::Parse
621 })?;
622 Ok(Some((t, e)))
623 }
624 Err(e) => {
625 warn!("recoverable - json error -> {:?}", e);
626 Ok(None)
627 }
628 }
629 } else {
630 Ok(None)
631 }
632 }
633
634 pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
635 let mut stmt = self
636 .conn
637 .prepare("SELECT token FROM account_t")
638 .map_err(|e| self.sqlite_error("select prepare", &e))?;
639
640 let data_iter = stmt
641 .query_map([], |row| row.get(0))
642 .map_err(|e| self.sqlite_error("query_map", &e))?;
643 let data: Result<Vec<Vec<u8>>, _> = data_iter
644 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
645 .collect();
646
647 let data = data?;
648
649 Ok(data
650 .iter()
651 .filter_map(|token| {
653 serde_json::from_slice(token.as_slice())
655 .map_err(|e| {
656 warn!("get_accounts json error -> {:?}", e);
657 })
658 .ok()
659 })
660 .collect())
661 }
662
663 pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
664 let data = serde_json::to_vec(account).map_err(|e| {
665 error!("update_account json error -> {:?}", e);
666 CacheError::SerdeJson
667 })?;
668 let expire = i64::try_from(expire).map_err(|e| {
669 error!("update_account i64 conversion error -> {:?}", e);
670 CacheError::Parse
671 })?;
672
673 let account_uuid = account.uuid.as_hyphenated().to_string();
677
678 self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
680 named_params!{
681 ":uuid": &account_uuid,
682 ":name": &account.name,
683 ":spn": &account.spn,
684 ":gidnumber": &account.gidnumber,
685 }
686 )
687 .map_err(|e| {
688 self.sqlite_error("delete account_t duplicate", &e)
689 })
690 .map(|_| ())?;
691
692 let updated = self.conn.execute(
693 "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
694 named_params!{
695 ":uuid": &account_uuid,
696 ":name": &account.name,
697 ":spn": &account.spn,
698 ":gidnumber": &account.gidnumber,
699 ":token": &data,
700 ":expiry": &expire,
701 }
702 )
703 .map_err(|e| {
704 self.sqlite_error("delete account_t duplicate", &e)
705 })?;
706
707 if updated == 0 {
708 let mut stmt = self.conn
709 .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
710 .map_err(|e| {
711 self.sqlite_error("prepare", &e)
712 })?;
713
714 stmt.execute(named_params! {
715 ":uuid": &account_uuid,
716 ":name": &account.name,
717 ":spn": &account.spn,
718 ":gidnumber": &account.gidnumber,
719 ":token": &data,
720 ":expiry": &expire,
721 })
722 .map(|r| {
723 trace!("insert -> {:?}", r);
724 })
725 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
726 }
727
728 let mut stmt = self
732 .conn
733 .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
734 .map_err(|e| self.sqlite_error("prepare", &e))?;
735
736 stmt.execute([&account_uuid])
737 .map(|r| {
738 trace!("delete memberships -> {:?}", r);
739 })
740 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
741
742 let mut stmt = self
743 .conn
744 .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
745 .map_err(|e| self.sqlite_error("prepare", &e))?;
746 account.groups.iter().try_for_each(|g| {
748 stmt.execute(named_params! {
749 ":a_uuid": &account_uuid,
750 ":g_uuid": &g.uuid.as_hyphenated().to_string(),
751 })
752 .map(|r| {
753 trace!("insert membership -> {:?}", r);
754 })
755 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
756 })
757 }
758
759 pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
760 let account_uuid = a_uuid.as_hyphenated().to_string();
761
762 self.conn
763 .execute(
764 "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
765 params![&account_uuid],
766 )
767 .map(|_| ())
768 .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
769
770 self.conn
771 .execute(
772 "DELETE FROM account_t WHERE uuid = :a_uuid",
773 params![&account_uuid],
774 )
775 .map(|_| ())
776 .map_err(|e| self.sqlite_error("account_t delete", &e))
777 }
778
779 pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
780 let data = match grp_id {
781 Id::Name(n) => self.get_group_data_name(n.as_str()),
782 Id::Gid(g) => self.get_group_data_gid(*g),
783 }?;
784
785 if data.len() >= 2 {
787 error!("invalid db state, multiple entries matched query?");
788 return Err(CacheError::TooManyResults);
789 }
790
791 if let Some((token, expiry)) = data.first() {
792 match serde_json::from_slice(token.as_slice()) {
796 Ok(t) => {
797 let e = u64::try_from(*expiry).map_err(|e| {
798 error!("u64 convert error -> {:?}", e);
799 CacheError::Parse
800 })?;
801 Ok(Some((t, e)))
802 }
803 Err(e) => {
804 warn!("recoverable - json error -> {:?}", e);
805 Ok(None)
806 }
807 }
808 } else {
809 Ok(None)
810 }
811 }
812
813 pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
814 let mut stmt = self
815 .conn
816 .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
817 .map_err(|e| {
818 self.sqlite_error("select prepare", &e)
819 })?;
820
821 let data_iter = stmt
822 .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
823 .map_err(|e| self.sqlite_error("query_map", &e))?;
824 let data: Result<Vec<Vec<u8>>, _> = data_iter
825 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
826 .collect();
827
828 let data = data?;
829
830 data.iter()
831 .map(|token| {
832 serde_json::from_slice(token.as_slice()).map_err(|e| {
835 error!("json error -> {:?}", e);
836 CacheError::SerdeJson
837 })
838 })
839 .collect()
840 }
841
842 pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
843 let mut stmt = self
844 .conn
845 .prepare("SELECT token FROM group_t")
846 .map_err(|e| self.sqlite_error("select prepare", &e))?;
847
848 let data_iter = stmt
849 .query_map([], |row| row.get(0))
850 .map_err(|e| self.sqlite_error("query_map", &e))?;
851 let data: Result<Vec<Vec<u8>>, _> = data_iter
852 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
853 .collect();
854
855 let data = data?;
856
857 Ok(data
858 .iter()
859 .filter_map(|token| {
860 serde_json::from_slice(token.as_slice())
863 .map_err(|e| {
864 error!("json error -> {:?}", e);
865 })
866 .ok()
867 })
868 .collect())
869 }
870
871 pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
872 let data = serde_json::to_vec(grp).map_err(|e| {
873 error!("json error -> {:?}", e);
874 CacheError::SerdeJson
875 })?;
876 let expire = i64::try_from(expire).map_err(|e| {
877 error!("i64 convert error -> {:?}", e);
878 CacheError::Parse
879 })?;
880
881 let mut stmt = self.conn
882 .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
883 .map_err(|e| {
884 self.sqlite_error("prepare", &e)
885 })?;
886
887 stmt.execute(named_params! {
889 ":uuid": &grp.uuid.as_hyphenated().to_string(),
890 ":name": &grp.name,
891 ":spn": &grp.spn,
892 ":gidnumber": &grp.gidnumber,
893 ":token": &data,
894 ":expiry": &expire,
895 })
896 .map(|r| {
897 trace!("insert -> {:?}", r);
898 })
899 .map_err(|e| self.sqlite_error("execute", &e))
900 }
901
902 pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
903 let group_uuid = g_uuid.as_hyphenated().to_string();
904 self.conn
905 .execute(
906 "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
907 [&group_uuid],
908 )
909 .map(|_| ())
910 .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
911 self.conn
912 .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
913 .map(|_| ())
914 .map_err(|e| self.sqlite_error("group_t delete", &e))
915 }
916}
917
918impl fmt::Debug for DbTxn<'_> {
919 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
920 write!(f, "DbTxn {{}}")
921 }
922}
923
924impl Drop for DbTxn<'_> {
925 fn drop(&mut self) {
927 if !self.committed {
928 #[allow(clippy::expect_used)]
930 self.conn
931 .execute("ROLLBACK TRANSACTION", [])
932 .expect("Unable to rollback transaction! Can not proceed!!!");
933 }
934 }
935}
936
#[cfg(test)]
mod tests {
    use super::{Cache, Db};
    use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};

    /// Accounts are reachable by name, spn, uuid and gid; a rename drops the old
    /// name/spn; `clear` removes everything.
    #[tokio::test]
    async fn test_cache_db_account_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());
        let id_spn = Id::Name("testuser@example.com".to_string());
        let id_spn2 = Id::Name("testuser2@example.com".to_string());
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // Nothing is cached yet.
        for id in [&id_name, &id_spn, &id_uuid, &id_gid] {
            assert!(dbtxn.get_account(id).unwrap().is_none());
        }

        dbtxn.update_account(&ut1, 0).unwrap();

        // Every identifier form resolves to the cached account.
        for id in [&id_name, &id_spn, &id_uuid, &id_gid] {
            assert!(dbtxn.get_account(id).unwrap().is_some());
        }

        // A rename replaces the old name/spn; uuid and gid keep resolving.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut1, 0).unwrap();

        for id in [&id_name, &id_spn] {
            assert!(dbtxn.get_account(id).unwrap().is_none());
        }
        for id in [&id_name2, &id_spn2, &id_uuid, &id_gid] {
            assert!(dbtxn.get_account(id).unwrap().is_some());
        }

        // Clearing the cache removes the entry entirely.
        assert!(dbtxn.clear().is_ok());
        for id in [&id_name2, &id_spn2, &id_uuid, &id_gid] {
            assert!(dbtxn.get_account(id).unwrap().is_none());
        }

        assert!(dbtxn.commit().is_ok());
    }

    /// Groups are reachable by name, spn, uuid and gid; a rename drops the old
    /// name/spn; `clear` removes everything.
    #[tokio::test]
    async fn test_cache_db_group_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());
        let id_spn = Id::Name("testgroup@example.com".to_string());
        let id_spn2 = Id::Name("testgroup2@example.com".to_string());
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // Nothing is cached yet.
        for id in [&id_name, &id_spn, &id_uuid, &id_gid] {
            assert!(dbtxn.get_group(id).unwrap().is_none());
        }

        dbtxn.update_group(&gt1, 0).unwrap();

        // Every identifier form resolves to the cached group.
        for id in [&id_name, &id_spn, &id_uuid, &id_gid] {
            assert!(dbtxn.get_group(id).unwrap().is_some());
        }

        // A rename replaces the old name/spn; uuid and gid keep resolving.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&gt1, 0).unwrap();

        for id in [&id_name, &id_spn] {
            assert!(dbtxn.get_group(id).unwrap().is_none());
        }
        for id in [&id_name2, &id_spn2, &id_uuid, &id_gid] {
            assert!(dbtxn.get_group(id).unwrap().is_some());
        }

        // Clearing the cache removes the entry entirely.
        assert!(dbtxn.clear().is_ok());
        for id in [&id_name2, &id_spn2, &id_uuid, &id_gid] {
            assert!(dbtxn.get_group(id).unwrap().is_none());
        }

        assert!(dbtxn.commit().is_ok());
    }

    /// Updating an account rewrites its memberships to exactly its token's groups.
    #[tokio::test]
    async fn test_cache_db_account_group_update() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
            extra_keys: Default::default(),
        };

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: vec![gt1.clone(), gt2],
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        for g in ut1.groups.iter() {
            dbtxn.update_group(g, 0).unwrap();
        }

        dbtxn.update_account(&ut1, 0).unwrap();

        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1[0].name, "testuser");
        assert_eq!(m2[0].name, "testuser");

        // Dropping a group from the token removes that membership.
        ut1.groups = vec![gt1];
        dbtxn.update_account(&ut1, 0).unwrap();

        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1[0].name, "testuser");
        assert!(m2.is_empty());

        assert!(dbtxn.commit().is_ok());
    }

    /// A second group taking over an existing name evicts the first; the first
    /// can later be re-added under a new name.
    #[tokio::test]
    async fn test_cache_db_group_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut gt1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let gt2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());

        assert!(dbtxn.get_group(&id_name).unwrap().is_none());

        dbtxn.update_group(&gt1, 0).unwrap();
        let r0 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // gt1 is renamed locally, but gt2 claims the old name first.
        gt1.name = "testgroup2".to_string();
        gt1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&gt2, 0).unwrap();
        let r2 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        assert!(dbtxn.get_group(&id_name2).unwrap().is_none());

        dbtxn.update_group(&gt1, 0).unwrap();

        let r4 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_group(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }

    /// A second account taking over an existing name evicts the first; the
    /// first can later be re-added under a new name.
    #[tokio::test]
    async fn test_cache_db_account_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let ut2 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());

        assert!(dbtxn.get_account(&id_name).unwrap().is_none());

        dbtxn.update_account(&ut1, 0).unwrap();
        let r0 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        // ut1 is renamed locally, but ut2 claims the old name first.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut2, 0).unwrap();
        let r2 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        assert!(dbtxn.get_account(&id_name2).unwrap().is_none());

        dbtxn.update_account(&ut1, 0).unwrap();

        let r4 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_account(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }
}