1use std::convert::TryFrom;
2use std::fmt;
3
4use crate::idprovider::interface::{GroupToken, Id, UserToken};
5use async_trait::async_trait;
6use libc::umask;
7use rusqlite::{Connection, OptionalExtension};
8use tokio::sync::{Mutex, MutexGuard};
9use uuid::Uuid;
10
11use serde::{de::DeserializeOwned, Serialize};
12
13use kanidm_hsm_crypto::{LoadableHmacKey, LoadableMachineKey};
14
/// Key in `db_version_t` under which the main cache schema version is recorded.
const DBV_MAIN: &str = "main";
16
/// A store that hands out write transactions.
#[async_trait]
pub trait Cache {
    /// Transaction handle type; it may borrow from the store for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Begin a write transaction.
    ///
    /// For [`Db`] this locks the connection and issues `BEGIN TRANSACTION`.
    async fn write<'db>(&'db self) -> Self::Txn<'db>;
}
25
/// A store that can open a transaction for HSM key storage.
///
/// NOTE(review): no implementation is visible in this file; presumably the
/// returned transaction is (or converts into) a [`KeyStoreTxn`] — confirm.
#[async_trait]
pub trait KeyStore {
    /// Transaction handle type; it may borrow from the store for `'db`.
    type Txn<'db>
    where
        Self: 'db;

    /// Begin a transaction for reading/writing HSM keys.
    async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
}
34
/// Errors returned by cache database operations.
#[derive(Debug)]
pub enum CacheError {
    /// Cryptographic failure. NOTE(review): not constructed in this file.
    Cryptography,
    /// JSON serialization or deserialization of a stored token/key failed.
    SerdeJson,
    /// Conversion between the stored `i64` expiry and the API's `u64` failed.
    Parse,
    /// An SQLite operation failed; details are logged where it happened.
    Sqlite,
    /// A lookup expected to match at most one row matched several.
    TooManyResults,
    /// The transaction was used in an invalid state (e.g. committed twice).
    TransactionInvalidState,
    /// TPM/HSM failure. NOTE(review): not constructed in this file.
    Tpm,
}
45
/// SQLite-backed cache database.
///
/// The single connection sits behind an async mutex; the accessors in this
/// file reach it through a [`DbTxn`] obtained from [`Cache::write`], which
/// holds the lock for as long as the transaction lives.
pub struct Db {
    conn: Mutex<Connection>,
}
49
/// An open SQLite transaction that owns the connection lock.
///
/// Dropping it without calling [`DbTxn::commit`] rolls the transaction back.
pub struct DbTxn<'a> {
    /// Locked connection; the lock is released when this value is dropped.
    conn: MutexGuard<'a, Connection>,
    /// Set by `commit` so that `Drop` does not issue a rollback.
    committed: bool,
}
54
/// A narrowed view of a [`DbTxn`] exposing only the tagged HSM key operations.
pub struct KeyStoreTxn<'a, 'b> {
    // Mutable borrow of the underlying transaction for the view's lifetime.
    db: &'b mut DbTxn<'a>,
}
58
59impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
60 fn from(db: &'b mut DbTxn<'a>) -> Self {
61 Self { db }
62 }
63}
64
/// Errors returned when opening a [`Db`].
#[derive(Debug)]
pub enum DbError {
    /// The SQLite connection could not be opened.
    Sqlite,
    /// TPM failure. NOTE(review): not constructed in this file.
    Tpm,
}
71
72impl Db {
73 pub fn new(path: &str) -> Result<Self, DbError> {
74 let before = unsafe { umask(0o0027) };
75 let conn = Connection::open(path).map_err(|e| {
76 error!(err = ?e, "rusqulite error");
77 DbError::Sqlite
78 })?;
79 let _ = unsafe { umask(before) };
80
81 Ok(Db {
82 conn: Mutex::new(conn),
83 })
84 }
85}
86
87#[async_trait]
88impl Cache for Db {
89 type Txn<'db> = DbTxn<'db>;
90
91 #[allow(clippy::expect_used)]
92 async fn write<'db>(&'db self) -> Self::Txn<'db> {
93 let conn = self.conn.lock().await;
94 DbTxn::new(conn)
95 }
96}
97
98impl fmt::Debug for Db {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(f, "Db {{}}")
101 }
102}
103
104impl<'a> DbTxn<'a> {
105 fn new(conn: MutexGuard<'a, Connection>) -> Self {
106 #[allow(clippy::expect_used)]
109 conn.execute("BEGIN TRANSACTION", [])
110 .expect("Unable to begin transaction!");
111 DbTxn {
112 committed: false,
113 conn,
114 }
115 }
116
117 fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
119 error!(
120 "sqlite {} error: {:?} db_path={:?}",
121 msg,
122 error,
123 &self.conn.path()
124 );
125 CacheError::Sqlite
126 }
127
128 fn sqlite_transaction_error(
130 &self,
131 error: &rusqlite::Error,
132 _stmt: &rusqlite::Statement,
133 ) -> CacheError {
134 error!(
135 "sqlite transaction error={:?} db_path={:?}",
136 error,
137 &self.conn.path(),
138 );
139 CacheError::Sqlite
141 }
142
143 fn get_db_version(&self, key: &str) -> i64 {
144 self.conn
145 .query_row(
146 "SELECT version FROM db_version_t WHERE id = :id",
147 &[(":id", key)],
148 |row| row.get(0),
149 )
150 .unwrap_or({
151 0
153 })
154 }
155
156 fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
157 self.conn
158 .execute(
159 "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
160 named_params! {
161 ":id": &key,
162 ":dbv": v,
163 },
164 )
165 .map(|_| ())
166 .map_err(|e| self.sqlite_error("set db_version_t", &e))
167 }
168
169 fn get_account_data_name(
170 &mut self,
171 account_id: &str,
172 ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
173 let mut stmt = self.conn
174 .prepare(
175 "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
176 )
177 .map_err(|e| {
178 self.sqlite_error("select prepare", &e)
179 })?;
180
181 let data_iter = stmt
183 .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
184 .map_err(|e| self.sqlite_error("query_map failure", &e))?;
185 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
186 .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
187 .collect();
188 data
189 }
190
191 fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
192 let mut stmt = self
193 .conn
194 .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
195 .map_err(|e| self.sqlite_error("select prepare", &e))?;
196
197 let data_iter = stmt
199 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
200 .map_err(|e| self.sqlite_error("query_map", &e))?;
201 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
202 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
203 .collect();
204 data
205 }
206
207 fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
208 let mut stmt = self.conn
209 .prepare(
210 "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
211 )
212 .map_err(|e| {
213 self.sqlite_error("select prepare", &e)
214 })?;
215
216 let data_iter = stmt
218 .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
219 .map_err(|e| self.sqlite_error("query_map", &e))?;
220 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
221 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
222 .collect();
223 data
224 }
225
226 fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
227 let mut stmt = self
228 .conn
229 .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
230 .map_err(|e| self.sqlite_error("select prepare", &e))?;
231
232 let data_iter = stmt
234 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
235 .map_err(|e| self.sqlite_error("query_map", &e))?;
236 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
237 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
238 .collect();
239 data
240 }
241}
242
243impl KeyStoreTxn<'_, '_> {
244 pub fn get_tagged_hsm_key<K: DeserializeOwned>(
245 &mut self,
246 tag: &str,
247 ) -> Result<Option<K>, CacheError> {
248 self.db.get_tagged_hsm_key(tag)
249 }
250
251 pub fn insert_tagged_hsm_key<K: Serialize>(
252 &mut self,
253 tag: &str,
254 key: &K,
255 ) -> Result<(), CacheError> {
256 self.db.insert_tagged_hsm_key(tag, key)
257 }
258
259 pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
260 self.db.delete_tagged_hsm_key(tag)
261 }
262}
263
264impl DbTxn<'_> {
265 fn get_tagged_hsm_key<K: DeserializeOwned>(
266 &mut self,
267 tag: &str,
268 ) -> Result<Option<K>, CacheError> {
269 let mut stmt = self
270 .conn
271 .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
272 .map_err(|e| self.sqlite_error("select prepare", &e))?;
273
274 let data: Option<Vec<u8>> = stmt
275 .query_row(
276 named_params! {
277 ":key": tag
278 },
279 |row| row.get(0),
280 )
281 .optional()
282 .map_err(|e| self.sqlite_error("query_row", &e))?;
283
284 match data {
285 Some(d) => Ok(serde_json::from_slice(d.as_slice())
286 .map_err(|e| {
287 error!("json error -> {:?}", e);
288 })
289 .ok()),
290 None => Ok(None),
291 }
292 }
293
294 fn insert_tagged_hsm_key<K: Serialize>(
295 &mut self,
296 tag: &str,
297 key: &K,
298 ) -> Result<(), CacheError> {
299 let data = serde_json::to_vec(key).map_err(|e| {
300 error!("json error -> {:?}", e);
301 CacheError::SerdeJson
302 })?;
303
304 let mut stmt = self
305 .conn
306 .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
307 .map_err(|e| self.sqlite_error("prepare", &e))?;
308
309 stmt.execute(named_params! {
310 ":key": tag,
311 ":value": &data,
312 })
313 .map(|r| {
314 trace!("insert -> {:?}", r);
315 })
316 .map_err(|e| self.sqlite_error("execute", &e))
317 }
318
319 fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
320 self.conn
321 .execute(
322 "DELETE FROM hsm_data_t where key = :key",
323 named_params! {
324 ":key": tag,
325 },
326 )
327 .map(|_| ())
328 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
329 }
330}
331
332impl DbTxn<'_> {
333 pub fn migrate(&mut self) -> Result<(), CacheError> {
334 self.conn.set_prepared_statement_cache_capacity(16);
335 self.conn
336 .prepare("PRAGMA journal_mode=WAL;")
337 .and_then(|mut wal_stmt| wal_stmt.query([]).map(|_| ()))
338 .map_err(|e| self.sqlite_error("account_t create", &e))?;
339
340 self.conn
342 .execute(
343 "CREATE TABLE IF NOT EXISTS db_version_t (
344 id TEXT PRIMARY KEY,
345 version INTEGER
346 )",
347 [],
348 )
349 .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
350
351 let db_version = self.get_db_version(DBV_MAIN);
352
353 if db_version < 1 {
354 self.conn
358 .execute(
359 "CREATE TABLE IF NOT EXISTS account_t (
360 uuid TEXT PRIMARY KEY,
361 name TEXT NOT NULL UNIQUE,
362 spn TEXT NOT NULL UNIQUE,
363 gidnumber INTEGER NOT NULL UNIQUE,
364 password BLOB,
365 token BLOB NOT NULL,
366 expiry NUMERIC NOT NULL
367 )
368 ",
369 [],
370 )
371 .map_err(|e| self.sqlite_error("account_t create", &e))?;
372
373 self.conn
374 .execute(
375 "CREATE TABLE IF NOT EXISTS group_t (
376 uuid TEXT PRIMARY KEY,
377 name TEXT NOT NULL UNIQUE,
378 spn TEXT NOT NULL UNIQUE,
379 gidnumber INTEGER NOT NULL UNIQUE,
380 token BLOB NOT NULL,
381 expiry NUMERIC NOT NULL
382 )
383 ",
384 [],
385 )
386 .map_err(|e| self.sqlite_error("group_t create", &e))?;
387
388 self.conn
395 .execute(
396 "CREATE TABLE IF NOT EXISTS memberof_t (
397 g_uuid TEXT,
398 a_uuid TEXT,
399 FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
400 FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
401 )
402 ",
403 [],
404 )
405 .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
406
407 self.conn
410 .execute(
411 "CREATE TABLE IF NOT EXISTS hsm_int_t (
412 key TEXT PRIMARY KEY,
413 value BLOB NOT NULL
414 )
415 ",
416 [],
417 )
418 .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
419
420 self.conn
421 .execute(
422 "CREATE TABLE IF NOT EXISTS hsm_data_t (
423 key TEXT PRIMARY KEY,
424 value BLOB NOT NULL
425 )
426 ",
427 [],
428 )
429 .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
430
431 self.clear_hsm()?;
433 }
434
435 self.set_db_version(DBV_MAIN, 1)?;
436
437 Ok(())
438 }
439
440 pub fn commit(mut self) -> Result<(), CacheError> {
441 if self.committed {
442 error!("Invalid state, SQL transaction was already committed!");
443 return Err(CacheError::TransactionInvalidState);
444 }
445 self.committed = true;
446
447 self.conn
448 .execute("COMMIT TRANSACTION", [])
449 .map(|_| ())
450 .map_err(|e| self.sqlite_error("commit", &e))
451 }
452
453 pub fn invalidate(&mut self) -> Result<(), CacheError> {
454 self.conn
455 .execute("UPDATE group_t SET expiry = 0", [])
456 .map_err(|e| self.sqlite_error("update group_t", &e))?;
457
458 self.conn
459 .execute("UPDATE account_t SET expiry = 0", [])
460 .map_err(|e| self.sqlite_error("update account_t", &e))?;
461
462 Ok(())
463 }
464
465 pub fn clear(&mut self) -> Result<(), CacheError> {
466 self.conn
467 .execute("DELETE FROM memberof_t", [])
468 .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
469
470 self.conn
471 .execute("DELETE FROM group_t", [])
472 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
473
474 self.conn
475 .execute("DELETE FROM account_t", [])
476 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
477
478 Ok(())
479 }
480
481 pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
482 self.clear()?;
483
484 self.conn
485 .execute("DELETE FROM hsm_int_t", [])
486 .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
487
488 self.conn
489 .execute("DELETE FROM hsm_data_t", [])
490 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
491
492 Ok(())
493 }
494
495 pub fn get_hsm_machine_key(&mut self) -> Result<Option<LoadableMachineKey>, CacheError> {
496 let mut stmt = self
497 .conn
498 .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
499 .map_err(|e| self.sqlite_error("select prepare", &e))?;
500
501 let data: Option<Vec<u8>> = stmt
502 .query_row([], |row| row.get(0))
503 .optional()
504 .map_err(|e| self.sqlite_error("query_row", &e))?;
505
506 match data {
507 Some(d) => Ok(serde_json::from_slice(d.as_slice())
508 .map_err(|e| {
509 error!("json error -> {:?}", e);
510 })
511 .ok()),
512 None => Ok(None),
513 }
514 }
515
516 pub fn insert_hsm_machine_key(
517 &mut self,
518 machine_key: &LoadableMachineKey,
519 ) -> Result<(), CacheError> {
520 let data = serde_json::to_vec(machine_key).map_err(|e| {
521 error!("insert_hsm_machine_key json error -> {:?}", e);
522 CacheError::SerdeJson
523 })?;
524
525 let mut stmt = self
526 .conn
527 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
528 .map_err(|e| self.sqlite_error("prepare", &e))?;
529
530 stmt.execute(named_params! {
531 ":key": "mk",
532 ":value": &data,
533 })
534 .map(|r| {
535 trace!("insert -> {:?}", r);
536 })
537 .map_err(|e| self.sqlite_error("execute", &e))
538 }
539
540 pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacKey>, CacheError> {
541 let mut stmt = self
542 .conn
543 .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
544 .map_err(|e| self.sqlite_error("select prepare", &e))?;
545
546 let data: Option<Vec<u8>> = stmt
547 .query_row([], |row| row.get(0))
548 .optional()
549 .map_err(|e| self.sqlite_error("query_row", &e))?;
550
551 match data {
552 Some(d) => Ok(serde_json::from_slice(d.as_slice())
553 .map_err(|e| {
554 error!("json error -> {:?}", e);
555 })
556 .ok()),
557 None => Ok(None),
558 }
559 }
560
561 pub fn insert_hsm_hmac_key(&mut self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError> {
562 let data = serde_json::to_vec(hmac_key).map_err(|e| {
563 error!("insert_hsm_hmac_key json error -> {:?}", e);
564 CacheError::SerdeJson
565 })?;
566
567 let mut stmt = self
568 .conn
569 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
570 .map_err(|e| self.sqlite_error("prepare", &e))?;
571
572 stmt.execute(named_params! {
573 ":key": "hmac",
574 ":value": &data,
575 })
576 .map(|r| {
577 trace!("insert -> {:?}", r);
578 })
579 .map_err(|e| self.sqlite_error("execute", &e))
580 }
581
582 pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
583 let data = match account_id {
584 Id::Name(n) => self.get_account_data_name(n.as_str()),
585 Id::Gid(g) => self.get_account_data_gid(*g),
586 }?;
587
588 if data.len() >= 2 {
590 error!("invalid db state, multiple entries matched query?");
591 return Err(CacheError::TooManyResults);
592 }
593
594 if let Some((token, expiry)) = data.first() {
595 match serde_json::from_slice(token.as_slice()) {
599 Ok(t) => {
600 let e = u64::try_from(*expiry).map_err(|e| {
601 error!("u64 convert error -> {:?}", e);
602 CacheError::Parse
603 })?;
604 Ok(Some((t, e)))
605 }
606 Err(e) => {
607 warn!("recoverable - json error -> {:?}", e);
608 Ok(None)
609 }
610 }
611 } else {
612 Ok(None)
613 }
614 }
615
616 pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
617 let mut stmt = self
618 .conn
619 .prepare("SELECT token FROM account_t")
620 .map_err(|e| self.sqlite_error("select prepare", &e))?;
621
622 let data_iter = stmt
623 .query_map([], |row| row.get(0))
624 .map_err(|e| self.sqlite_error("query_map", &e))?;
625 let data: Result<Vec<Vec<u8>>, _> = data_iter
626 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
627 .collect();
628
629 let data = data?;
630
631 Ok(data
632 .iter()
633 .filter_map(|token| {
635 serde_json::from_slice(token.as_slice())
637 .map_err(|e| {
638 warn!("get_accounts json error -> {:?}", e);
639 })
640 .ok()
641 })
642 .collect())
643 }
644
645 pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
646 let data = serde_json::to_vec(account).map_err(|e| {
647 error!("update_account json error -> {:?}", e);
648 CacheError::SerdeJson
649 })?;
650 let expire = i64::try_from(expire).map_err(|e| {
651 error!("update_account i64 conversion error -> {:?}", e);
652 CacheError::Parse
653 })?;
654
655 let account_uuid = account.uuid.as_hyphenated().to_string();
659
660 self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
662 named_params!{
663 ":uuid": &account_uuid,
664 ":name": &account.name,
665 ":spn": &account.spn,
666 ":gidnumber": &account.gidnumber,
667 }
668 )
669 .map_err(|e| {
670 self.sqlite_error("delete account_t duplicate", &e)
671 })
672 .map(|_| ())?;
673
674 let updated = self.conn.execute(
675 "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
676 named_params!{
677 ":uuid": &account_uuid,
678 ":name": &account.name,
679 ":spn": &account.spn,
680 ":gidnumber": &account.gidnumber,
681 ":token": &data,
682 ":expiry": &expire,
683 }
684 )
685 .map_err(|e| {
686 self.sqlite_error("delete account_t duplicate", &e)
687 })?;
688
689 if updated == 0 {
690 let mut stmt = self.conn
691 .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
692 .map_err(|e| {
693 self.sqlite_error("prepare", &e)
694 })?;
695
696 stmt.execute(named_params! {
697 ":uuid": &account_uuid,
698 ":name": &account.name,
699 ":spn": &account.spn,
700 ":gidnumber": &account.gidnumber,
701 ":token": &data,
702 ":expiry": &expire,
703 })
704 .map(|r| {
705 trace!("insert -> {:?}", r);
706 })
707 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
708 }
709
710 let mut stmt = self
714 .conn
715 .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
716 .map_err(|e| self.sqlite_error("prepare", &e))?;
717
718 stmt.execute([&account_uuid])
719 .map(|r| {
720 trace!("delete memberships -> {:?}", r);
721 })
722 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
723
724 let mut stmt = self
725 .conn
726 .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
727 .map_err(|e| self.sqlite_error("prepare", &e))?;
728 account.groups.iter().try_for_each(|g| {
730 stmt.execute(named_params! {
731 ":a_uuid": &account_uuid,
732 ":g_uuid": &g.uuid.as_hyphenated().to_string(),
733 })
734 .map(|r| {
735 trace!("insert membership -> {:?}", r);
736 })
737 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
738 })
739 }
740
741 pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
742 let account_uuid = a_uuid.as_hyphenated().to_string();
743
744 self.conn
745 .execute(
746 "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
747 params![&account_uuid],
748 )
749 .map(|_| ())
750 .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
751
752 self.conn
753 .execute(
754 "DELETE FROM account_t WHERE uuid = :a_uuid",
755 params![&account_uuid],
756 )
757 .map(|_| ())
758 .map_err(|e| self.sqlite_error("account_t delete", &e))
759 }
760
761 pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
762 let data = match grp_id {
763 Id::Name(n) => self.get_group_data_name(n.as_str()),
764 Id::Gid(g) => self.get_group_data_gid(*g),
765 }?;
766
767 if data.len() >= 2 {
769 error!("invalid db state, multiple entries matched query?");
770 return Err(CacheError::TooManyResults);
771 }
772
773 if let Some((token, expiry)) = data.first() {
774 match serde_json::from_slice(token.as_slice()) {
778 Ok(t) => {
779 let e = u64::try_from(*expiry).map_err(|e| {
780 error!("u64 convert error -> {:?}", e);
781 CacheError::Parse
782 })?;
783 Ok(Some((t, e)))
784 }
785 Err(e) => {
786 warn!("recoverable - json error -> {:?}", e);
787 Ok(None)
788 }
789 }
790 } else {
791 Ok(None)
792 }
793 }
794
795 pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
796 let mut stmt = self
797 .conn
798 .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
799 .map_err(|e| {
800 self.sqlite_error("select prepare", &e)
801 })?;
802
803 let data_iter = stmt
804 .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
805 .map_err(|e| self.sqlite_error("query_map", &e))?;
806 let data: Result<Vec<Vec<u8>>, _> = data_iter
807 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
808 .collect();
809
810 let data = data?;
811
812 data.iter()
813 .map(|token| {
814 serde_json::from_slice(token.as_slice()).map_err(|e| {
817 error!("json error -> {:?}", e);
818 CacheError::SerdeJson
819 })
820 })
821 .collect()
822 }
823
824 pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
825 let mut stmt = self
826 .conn
827 .prepare("SELECT token FROM group_t")
828 .map_err(|e| self.sqlite_error("select prepare", &e))?;
829
830 let data_iter = stmt
831 .query_map([], |row| row.get(0))
832 .map_err(|e| self.sqlite_error("query_map", &e))?;
833 let data: Result<Vec<Vec<u8>>, _> = data_iter
834 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
835 .collect();
836
837 let data = data?;
838
839 Ok(data
840 .iter()
841 .filter_map(|token| {
842 serde_json::from_slice(token.as_slice())
845 .map_err(|e| {
846 error!("json error -> {:?}", e);
847 })
848 .ok()
849 })
850 .collect())
851 }
852
853 pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
854 let data = serde_json::to_vec(grp).map_err(|e| {
855 error!("json error -> {:?}", e);
856 CacheError::SerdeJson
857 })?;
858 let expire = i64::try_from(expire).map_err(|e| {
859 error!("i64 convert error -> {:?}", e);
860 CacheError::Parse
861 })?;
862
863 let mut stmt = self.conn
864 .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
865 .map_err(|e| {
866 self.sqlite_error("prepare", &e)
867 })?;
868
869 stmt.execute(named_params! {
871 ":uuid": &grp.uuid.as_hyphenated().to_string(),
872 ":name": &grp.name,
873 ":spn": &grp.spn,
874 ":gidnumber": &grp.gidnumber,
875 ":token": &data,
876 ":expiry": &expire,
877 })
878 .map(|r| {
879 trace!("insert -> {:?}", r);
880 })
881 .map_err(|e| self.sqlite_error("execute", &e))
882 }
883
884 pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
885 let group_uuid = g_uuid.as_hyphenated().to_string();
886 self.conn
887 .execute(
888 "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
889 [&group_uuid],
890 )
891 .map(|_| ())
892 .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
893 self.conn
894 .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
895 .map(|_| ())
896 .map_err(|e| self.sqlite_error("group_t delete", &e))
897 }
898}
899
900impl fmt::Debug for DbTxn<'_> {
901 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
902 write!(f, "DbTxn {{}}")
903 }
904}
905
906impl Drop for DbTxn<'_> {
907 fn drop(&mut self) {
909 if !self.committed {
910 #[allow(clippy::expect_used)]
912 self.conn
913 .execute("ROLLBACK TRANSACTION", [])
914 .expect("Unable to rollback transaction! Can not proceed!!!");
915 }
916 }
917}
918
/// Unit tests for the cache database.
///
/// Each test opens `Db::new("")`, which gives SQLite a private temporary
/// database, then runs `migrate` inside a single write transaction.
/// Group fixtures are named `grp1`/`grp2` (not `gt1`/`gt2`) so that
/// `&grp1` cannot be mangled into an HTML `&gt` entity.
#[cfg(test)]
mod tests {
    use super::{Cache, Db};
    use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};

    /// An account can be found by name, spn, uuid and gid, follows a rename,
    /// and disappears after `clear`.
    #[tokio::test]
    async fn test_cache_db_account_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());
        let id_spn = Id::Name("testuser@example.com".to_string());
        let id_spn2 = Id::Name("testuser2@example.com".to_string());
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        // Nothing is cached yet.
        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        dbtxn.update_account(&ut1, 0).unwrap();

        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        // Rename: the old name/spn must stop resolving.
        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut1, 0).unwrap();

        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_some());

        assert!(dbtxn.clear().is_ok());

        let r1 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_account(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_account(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_account(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// A group can be found by name, spn, uuid and gid, follows a rename,
    /// and disappears after `clear`.
    #[tokio::test]
    async fn test_cache_db_group_basic() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut grp1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());
        let id_spn = Id::Name("testgroup@example.com".to_string());
        let id_spn2 = Id::Name("testgroup2@example.com".to_string());
        let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
        let id_gid = Id::Gid(2000);

        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        dbtxn.update_group(&grp1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        grp1.name = "testgroup2".to_string();
        grp1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&grp1, 0).unwrap();
        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn).unwrap();
        assert!(r2.is_none());
        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_some());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_some());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_some());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_some());

        assert!(dbtxn.clear().is_ok());

        let r1 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r1.is_none());
        let r2 = dbtxn.get_group(&id_spn2).unwrap();
        assert!(r2.is_none());
        let r3 = dbtxn.get_group(&id_uuid).unwrap();
        assert!(r3.is_none());
        let r4 = dbtxn.get_group(&id_gid).unwrap();
        assert!(r4.is_none());

        assert!(dbtxn.commit().is_ok());
    }

    /// Updating an account's `groups` replaces its membership rows.
    #[tokio::test]
    async fn test_cache_db_account_group_update() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let grp1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let grp2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
            extra_keys: Default::default(),
        };

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: vec![grp1.clone(), grp2],
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        ut1.groups.iter().for_each(|g| {
            dbtxn.update_group(g, 0).unwrap();
        });

        dbtxn.update_account(&ut1, 0).unwrap();

        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1[0].name, "testuser");
        assert_eq!(m2[0].name, "testuser");

        // Dropping the second group must remove that membership row.
        ut1.groups = vec![grp1];
        dbtxn.update_account(&ut1, 0).unwrap();

        let m1 = dbtxn
            .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
            .unwrap();
        let m2 = dbtxn
            .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
            .unwrap();
        assert_eq!(m1[0].name, "testuser");
        assert!(m2.is_empty());

        assert!(dbtxn.commit().is_ok());
    }

    /// A second group claiming an existing name displaces the first; the first
    /// can then be re-added under a new name.
    #[tokio::test]
    async fn test_cache_db_group_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut grp1 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            extra_keys: Default::default(),
        };

        let grp2 = GroupToken {
            provider: ProviderOrigin::System,
            name: "testgroup".to_string(),
            spn: "testgroup@example.com".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testgroup".to_string());
        let id_name2 = Id::Name("testgroup2".to_string());

        let r1 = dbtxn.get_group(&id_name).unwrap();
        assert!(r1.is_none());

        dbtxn.update_group(&grp1, 0).unwrap();
        let r0 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        grp1.name = "testgroup2".to_string();
        grp1.spn = "testgroup2@example.com".to_string();
        dbtxn.update_group(&grp2, 0).unwrap();
        let r2 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_group(&id_name2).unwrap();
        assert!(r3.is_none());

        dbtxn.update_group(&grp1, 0).unwrap();

        let r4 = dbtxn.get_group(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_group(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }

    /// A second account claiming an existing name displaces the first; the
    /// first can then be re-added under a new name.
    #[tokio::test]
    async fn test_cache_db_account_rename_duplicate() {
        sketching::test_init();
        let db = Db::new("").expect("failed to create.");
        let mut dbtxn = db.write().await;
        assert!(dbtxn.migrate().is_ok());

        let mut ut1 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2000,
            uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let ut2 = UserToken {
            provider: ProviderOrigin::System,
            name: "testuser".to_string(),
            spn: "testuser@example.com".to_string(),
            displayname: "Test User".to_string(),
            gidnumber: 2001,
            uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
            shell: None,
            groups: Vec::new(),
            sshkeys: vec!["key-a".to_string()],
            valid: true,
            extra_keys: Default::default(),
        };

        let id_name = Id::Name("testuser".to_string());
        let id_name2 = Id::Name("testuser2".to_string());

        let r1 = dbtxn.get_account(&id_name).unwrap();
        assert!(r1.is_none());

        dbtxn.update_account(&ut1, 0).unwrap();
        let r0 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r0.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        ut1.name = "testuser2".to_string();
        ut1.spn = "testuser2@example.com".to_string();
        dbtxn.update_account(&ut2, 0).unwrap();
        let r2 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r2.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r3 = dbtxn.get_account(&id_name2).unwrap();
        assert!(r3.is_none());

        dbtxn.update_account(&ut1, 0).unwrap();

        let r4 = dbtxn.get_account(&id_name).unwrap();
        assert_eq!(
            r4.unwrap().0.uuid,
            uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
        );
        let r5 = dbtxn.get_account(&id_name2).unwrap();
        assert_eq!(
            r5.unwrap().0.uuid,
            uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
        );

        assert!(dbtxn.commit().is_ok());
    }
}