1use std::convert::TryFrom;
2use std::fmt;
3
4use crate::idprovider::interface::{GroupToken, Id, UserToken};
5use async_trait::async_trait;
6use libc::umask;
7use rusqlite::{Connection, OptionalExtension};
8use tokio::sync::{Mutex, MutexGuard};
9use uuid::Uuid;
10
11use serde::{de::DeserializeOwned, Serialize};
12
13use kanidm_hsm_crypto::{LoadableHmacKey, LoadableMachineKey};
14
/// Row id in `db_version_t` under which the main schema version is stored.
const DBV_MAIN: &str = "main";
/// Value handed to the sqlite `cache_size` pragma: 32 MiB divided into
/// 4096-byte units (presumably sqlite pages — confirm against the page size).
const CACHE_SIZE: usize = (32 * 1024 * 1024) / 4096;
19
20#[async_trait]
21pub trait Cache {
22 type Txn<'db>
23 where
24 Self: 'db;
25
26 async fn write<'db>(&'db self) -> Self::Txn<'db>;
27}
28
29#[async_trait]
30pub trait KeyStore {
31 type Txn<'db>
32 where
33 Self: 'db;
34
35 async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
36}
37
/// Failure categories reported by the cache transaction methods.
#[derive(Debug)]
pub enum CacheError {
    /// Not produced by the code in this file; presumably reserved for
    /// cryptographic failures — confirm against callers.
    Cryptography,
    /// A value could not be serialised to JSON.
    SerdeJson,
    /// An expiry value did not fit the target integer type.
    Parse,
    /// A sqlite operation failed; details are logged where the error occurs.
    Sqlite,
    /// A lookup expected at most one row but matched several.
    TooManyResults,
    /// A transaction was asked to commit after it was already committed.
    TransactionInvalidState,
    /// Not produced by the code in this file; presumably TPM related — confirm.
    Tpm,
}
48
49pub struct Db {
50 conn: Mutex<Connection>,
51}
52
53pub struct DbTxn<'a> {
54 conn: MutexGuard<'a, Connection>,
55 committed: bool,
56}
57
58pub struct KeyStoreTxn<'a, 'b> {
59 db: &'b mut DbTxn<'a>,
60}
61
62impl<'a, 'b> From<&'b mut DbTxn<'a>> for KeyStoreTxn<'a, 'b> {
63 fn from(db: &'b mut DbTxn<'a>) -> Self {
64 Self { db }
65 }
66}
67
/// Failure categories reported while opening the database.
#[derive(Debug)]
pub enum DbError {
    /// The sqlite database could not be opened or configured.
    Sqlite,
    /// Not produced by the code in this file; presumably TPM related — confirm.
    Tpm,
}
74
75impl Db {
76 pub fn new(path: &str) -> Result<Self, DbError> {
77 let before = unsafe { umask(0o0027) };
78 let conn = Connection::open(path).map_err(|e| {
79 error!(err = ?e, "rusqulite error");
80 DbError::Sqlite
81 })?;
82 let _ = unsafe { umask(before) };
83
84 conn.pragma_update(None, "journal_mode", "WAL")
86 .map_err(|error| {
87 error!(
88 "sqlite journal_mode=WAL error: {:?} db_path={:?}",
89 error, path
90 );
91 DbError::Sqlite
92 })?;
93
94 conn.pragma_update(None, "cache_size", CACHE_SIZE)
95 .map_err(|error| {
96 error!(
97 "sqlite cache_size={} error: {:?} db_path={:?}",
98 CACHE_SIZE, error, path
99 );
100 DbError::Sqlite
101 })?;
102
103 conn.set_prepared_statement_cache_capacity(32);
104
105 Ok(Db {
106 conn: Mutex::new(conn),
107 })
108 }
109}
110
111#[async_trait]
112impl Cache for Db {
113 type Txn<'db> = DbTxn<'db>;
114
115 #[allow(clippy::expect_used)]
116 async fn write<'db>(&'db self) -> Self::Txn<'db> {
117 let conn = self.conn.lock().await;
118 DbTxn::new(conn)
119 }
120}
121
122impl fmt::Debug for Db {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "Db {{}}")
125 }
126}
127
128impl<'a> DbTxn<'a> {
129 fn new(conn: MutexGuard<'a, Connection>) -> Self {
130 #[allow(clippy::expect_used)]
133 conn.execute("BEGIN TRANSACTION", [])
134 .expect("Unable to begin transaction!");
135 DbTxn {
136 committed: false,
137 conn,
138 }
139 }
140
141 fn sqlite_error(&self, msg: &str, error: &rusqlite::Error) -> CacheError {
143 error!(
144 "sqlite {} error: {:?} db_path={:?}",
145 msg,
146 error,
147 &self.conn.path()
148 );
149 CacheError::Sqlite
150 }
151
152 fn sqlite_transaction_error(
154 &self,
155 error: &rusqlite::Error,
156 _stmt: &rusqlite::Statement,
157 ) -> CacheError {
158 error!(
159 "sqlite transaction error={:?} db_path={:?}",
160 error,
161 &self.conn.path(),
162 );
163 CacheError::Sqlite
165 }
166
167 fn get_db_version(&self, key: &str) -> i64 {
168 self.conn
169 .query_row(
170 "SELECT version FROM db_version_t WHERE id = :id",
171 &[(":id", key)],
172 |row| row.get(0),
173 )
174 .unwrap_or({
175 0
177 })
178 }
179
180 fn set_db_version(&self, key: &str, v: i64) -> Result<(), CacheError> {
181 self.conn
182 .execute(
183 "INSERT OR REPLACE INTO db_version_t (id, version) VALUES(:id, :dbv)",
184 named_params! {
185 ":id": &key,
186 ":dbv": v,
187 },
188 )
189 .map(|_| ())
190 .map_err(|e| self.sqlite_error("set db_version_t", &e))
191 }
192
193 fn get_account_data_name(
194 &mut self,
195 account_id: &str,
196 ) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
197 let mut stmt = self.conn
198 .prepare(
199 "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
200 )
201 .map_err(|e| {
202 self.sqlite_error("select prepare", &e)
203 })?;
204
205 let data_iter = stmt
207 .query_map([account_id], |row| Ok((row.get(0)?, row.get(1)?)))
208 .map_err(|e| self.sqlite_error("query_map failure", &e))?;
209 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
210 .map(|v| v.map_err(|e| self.sqlite_error("map failure", &e)))
211 .collect();
212 data
213 }
214
215 fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
216 let mut stmt = self
217 .conn
218 .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
219 .map_err(|e| self.sqlite_error("select prepare", &e))?;
220
221 let data_iter = stmt
223 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
224 .map_err(|e| self.sqlite_error("query_map", &e))?;
225 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
226 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
227 .collect();
228 data
229 }
230
231 fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
232 let mut stmt = self.conn
233 .prepare(
234 "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
235 )
236 .map_err(|e| {
237 self.sqlite_error("select prepare", &e)
238 })?;
239
240 let data_iter = stmt
242 .query_map([grp_id], |row| Ok((row.get(0)?, row.get(1)?)))
243 .map_err(|e| self.sqlite_error("query_map", &e))?;
244 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
245 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
246 .collect();
247 data
248 }
249
250 fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
251 let mut stmt = self
252 .conn
253 .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
254 .map_err(|e| self.sqlite_error("select prepare", &e))?;
255
256 let data_iter = stmt
258 .query_map(params![gid], |row| Ok((row.get(0)?, row.get(1)?)))
259 .map_err(|e| self.sqlite_error("query_map", &e))?;
260 let data: Result<Vec<(Vec<u8>, i64)>, _> = data_iter
261 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
262 .collect();
263 data
264 }
265}
266
267impl KeyStoreTxn<'_, '_> {
268 pub fn get_tagged_hsm_key<K: DeserializeOwned>(
269 &mut self,
270 tag: &str,
271 ) -> Result<Option<K>, CacheError> {
272 self.db.get_tagged_hsm_key(tag)
273 }
274
275 pub fn insert_tagged_hsm_key<K: Serialize>(
276 &mut self,
277 tag: &str,
278 key: &K,
279 ) -> Result<(), CacheError> {
280 self.db.insert_tagged_hsm_key(tag, key)
281 }
282
283 pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
284 self.db.delete_tagged_hsm_key(tag)
285 }
286}
287
288impl DbTxn<'_> {
289 fn get_tagged_hsm_key<K: DeserializeOwned>(
290 &mut self,
291 tag: &str,
292 ) -> Result<Option<K>, CacheError> {
293 let mut stmt = self
294 .conn
295 .prepare("SELECT value FROM hsm_data_t WHERE key = :key")
296 .map_err(|e| self.sqlite_error("select prepare", &e))?;
297
298 let data: Option<Vec<u8>> = stmt
299 .query_row(
300 named_params! {
301 ":key": tag
302 },
303 |row| row.get(0),
304 )
305 .optional()
306 .map_err(|e| self.sqlite_error("query_row", &e))?;
307
308 match data {
309 Some(d) => Ok(serde_json::from_slice(d.as_slice())
310 .map_err(|e| {
311 error!("json error -> {:?}", e);
312 })
313 .ok()),
314 None => Ok(None),
315 }
316 }
317
318 fn insert_tagged_hsm_key<K: Serialize>(
319 &mut self,
320 tag: &str,
321 key: &K,
322 ) -> Result<(), CacheError> {
323 let data = serde_json::to_vec(key).map_err(|e| {
324 error!("json error -> {:?}", e);
325 CacheError::SerdeJson
326 })?;
327
328 let mut stmt = self
329 .conn
330 .prepare("INSERT OR REPLACE INTO hsm_data_t (key, value) VALUES (:key, :value)")
331 .map_err(|e| self.sqlite_error("prepare", &e))?;
332
333 stmt.execute(named_params! {
334 ":key": tag,
335 ":value": &data,
336 })
337 .map(|r| {
338 trace!("insert -> {:?}", r);
339 })
340 .map_err(|e| self.sqlite_error("execute", &e))
341 }
342
343 fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
344 self.conn
345 .execute(
346 "DELETE FROM hsm_data_t where key = :key",
347 named_params! {
348 ":key": tag,
349 },
350 )
351 .map(|_| ())
352 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))
353 }
354}
355
356impl DbTxn<'_> {
357 pub fn migrate(&mut self) -> Result<(), CacheError> {
358 self.conn
360 .execute(
361 "CREATE TABLE IF NOT EXISTS db_version_t (
362 id TEXT PRIMARY KEY,
363 version INTEGER
364 )",
365 [],
366 )
367 .map_err(|e| self.sqlite_error("db_version_t create", &e))?;
368
369 let db_version = self.get_db_version(DBV_MAIN);
370
371 if db_version < 1 {
372 self.conn
376 .execute(
377 "CREATE TABLE IF NOT EXISTS account_t (
378 uuid TEXT PRIMARY KEY,
379 name TEXT NOT NULL UNIQUE,
380 spn TEXT NOT NULL UNIQUE,
381 gidnumber INTEGER NOT NULL UNIQUE,
382 password BLOB,
383 token BLOB NOT NULL,
384 expiry NUMERIC NOT NULL
385 )
386 ",
387 [],
388 )
389 .map_err(|e| self.sqlite_error("account_t create", &e))?;
390
391 self.conn
392 .execute(
393 "CREATE TABLE IF NOT EXISTS group_t (
394 uuid TEXT PRIMARY KEY,
395 name TEXT NOT NULL UNIQUE,
396 spn TEXT NOT NULL UNIQUE,
397 gidnumber INTEGER NOT NULL UNIQUE,
398 token BLOB NOT NULL,
399 expiry NUMERIC NOT NULL
400 )
401 ",
402 [],
403 )
404 .map_err(|e| self.sqlite_error("group_t create", &e))?;
405
406 self.conn
413 .execute(
414 "CREATE TABLE IF NOT EXISTS memberof_t (
415 g_uuid TEXT,
416 a_uuid TEXT,
417 FOREIGN KEY(g_uuid) REFERENCES group_t(uuid) DEFERRABLE INITIALLY DEFERRED,
418 FOREIGN KEY(a_uuid) REFERENCES account_t(uuid) ON DELETE CASCADE
419 )
420 ",
421 [],
422 )
423 .map_err(|e| self.sqlite_error("memberof_t create error", &e))?;
424
425 self.conn
428 .execute(
429 "CREATE TABLE IF NOT EXISTS hsm_int_t (
430 key TEXT PRIMARY KEY,
431 value BLOB NOT NULL
432 )
433 ",
434 [],
435 )
436 .map_err(|e| self.sqlite_error("hsm_int_t create error", &e))?;
437
438 self.conn
439 .execute(
440 "CREATE TABLE IF NOT EXISTS hsm_data_t (
441 key TEXT PRIMARY KEY,
442 value BLOB NOT NULL
443 )
444 ",
445 [],
446 )
447 .map_err(|e| self.sqlite_error("hsm_data_t create error", &e))?;
448
449 self.clear_hsm()?;
451 }
452
453 self.set_db_version(DBV_MAIN, 1)?;
454
455 Ok(())
456 }
457
458 pub fn commit(mut self) -> Result<(), CacheError> {
459 if self.committed {
460 error!("Invalid state, SQL transaction was already committed!");
461 return Err(CacheError::TransactionInvalidState);
462 }
463 self.committed = true;
464
465 self.conn
466 .execute("COMMIT TRANSACTION", [])
467 .map(|_| ())
468 .map_err(|e| self.sqlite_error("commit", &e))
469 }
470
471 pub fn invalidate(&mut self) -> Result<(), CacheError> {
472 self.conn
473 .execute("UPDATE group_t SET expiry = 0", [])
474 .map_err(|e| self.sqlite_error("update group_t", &e))?;
475
476 self.conn
477 .execute("UPDATE account_t SET expiry = 0", [])
478 .map_err(|e| self.sqlite_error("update account_t", &e))?;
479
480 Ok(())
481 }
482
483 pub fn clear(&mut self) -> Result<(), CacheError> {
484 self.conn
485 .execute("DELETE FROM memberof_t", [])
486 .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
487
488 self.conn
489 .execute("DELETE FROM group_t", [])
490 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
491
492 self.conn
493 .execute("DELETE FROM account_t", [])
494 .map_err(|e| self.sqlite_error("delete group_t", &e))?;
495
496 Ok(())
497 }
498
499 pub fn clear_hsm(&mut self) -> Result<(), CacheError> {
500 self.clear()?;
501
502 self.conn
503 .execute("DELETE FROM hsm_int_t", [])
504 .map_err(|e| self.sqlite_error("delete hsm_int_t", &e))?;
505
506 self.conn
507 .execute("DELETE FROM hsm_data_t", [])
508 .map_err(|e| self.sqlite_error("delete hsm_data_t", &e))?;
509
510 Ok(())
511 }
512
513 pub fn get_hsm_machine_key(&mut self) -> Result<Option<LoadableMachineKey>, CacheError> {
514 let mut stmt = self
515 .conn
516 .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
517 .map_err(|e| self.sqlite_error("select prepare", &e))?;
518
519 let data: Option<Vec<u8>> = stmt
520 .query_row([], |row| row.get(0))
521 .optional()
522 .map_err(|e| self.sqlite_error("query_row", &e))?;
523
524 match data {
525 Some(d) => Ok(serde_json::from_slice(d.as_slice())
526 .map_err(|e| {
527 error!("json error -> {:?}", e);
528 })
529 .ok()),
530 None => Ok(None),
531 }
532 }
533
534 pub fn insert_hsm_machine_key(
535 &mut self,
536 machine_key: &LoadableMachineKey,
537 ) -> Result<(), CacheError> {
538 let data = serde_json::to_vec(machine_key).map_err(|e| {
539 error!("insert_hsm_machine_key json error -> {:?}", e);
540 CacheError::SerdeJson
541 })?;
542
543 let mut stmt = self
544 .conn
545 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
546 .map_err(|e| self.sqlite_error("prepare", &e))?;
547
548 stmt.execute(named_params! {
549 ":key": "mk",
550 ":value": &data,
551 })
552 .map(|r| {
553 trace!("insert -> {:?}", r);
554 })
555 .map_err(|e| self.sqlite_error("execute", &e))
556 }
557
558 pub fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacKey>, CacheError> {
559 let mut stmt = self
560 .conn
561 .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
562 .map_err(|e| self.sqlite_error("select prepare", &e))?;
563
564 let data: Option<Vec<u8>> = stmt
565 .query_row([], |row| row.get(0))
566 .optional()
567 .map_err(|e| self.sqlite_error("query_row", &e))?;
568
569 match data {
570 Some(d) => Ok(serde_json::from_slice(d.as_slice())
571 .map_err(|e| {
572 error!("json error -> {:?}", e);
573 })
574 .ok()),
575 None => Ok(None),
576 }
577 }
578
579 pub fn insert_hsm_hmac_key(&mut self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError> {
580 let data = serde_json::to_vec(hmac_key).map_err(|e| {
581 error!("insert_hsm_hmac_key json error -> {:?}", e);
582 CacheError::SerdeJson
583 })?;
584
585 let mut stmt = self
586 .conn
587 .prepare("INSERT OR REPLACE INTO hsm_int_t (key, value) VALUES (:key, :value)")
588 .map_err(|e| self.sqlite_error("prepare", &e))?;
589
590 stmt.execute(named_params! {
591 ":key": "hmac",
592 ":value": &data,
593 })
594 .map(|r| {
595 trace!("insert -> {:?}", r);
596 })
597 .map_err(|e| self.sqlite_error("execute", &e))
598 }
599
600 pub fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
601 let data = match account_id {
602 Id::Name(n) => self.get_account_data_name(n.as_str()),
603 Id::Gid(g) => self.get_account_data_gid(*g),
604 }?;
605
606 if data.len() >= 2 {
608 error!("invalid db state, multiple entries matched query?");
609 return Err(CacheError::TooManyResults);
610 }
611
612 if let Some((token, expiry)) = data.first() {
613 match serde_json::from_slice(token.as_slice()) {
617 Ok(t) => {
618 let e = u64::try_from(*expiry).map_err(|e| {
619 error!("u64 convert error -> {:?}", e);
620 CacheError::Parse
621 })?;
622 Ok(Some((t, e)))
623 }
624 Err(e) => {
625 warn!("recoverable - json error -> {:?}", e);
626 Ok(None)
627 }
628 }
629 } else {
630 Ok(None)
631 }
632 }
633
634 pub fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
635 let mut stmt = self
636 .conn
637 .prepare("SELECT token FROM account_t")
638 .map_err(|e| self.sqlite_error("select prepare", &e))?;
639
640 let data_iter = stmt
641 .query_map([], |row| row.get(0))
642 .map_err(|e| self.sqlite_error("query_map", &e))?;
643 let data: Result<Vec<Vec<u8>>, _> = data_iter
644 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
645 .collect();
646
647 let data = data?;
648
649 Ok(data
650 .iter()
651 .filter_map(|token| {
653 serde_json::from_slice(token.as_slice())
655 .map_err(|e| {
656 warn!("get_accounts json error -> {:?}", e);
657 })
658 .ok()
659 })
660 .collect())
661 }
662
663 pub fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
664 let data = serde_json::to_vec(account).map_err(|e| {
665 error!("update_account json error -> {:?}", e);
666 CacheError::SerdeJson
667 })?;
668 let expire = i64::try_from(expire).map_err(|e| {
669 error!("update_account i64 conversion error -> {:?}", e);
670 CacheError::Parse
671 })?;
672
673 let account_uuid = account.uuid.as_hyphenated().to_string();
677
678 self.conn.execute("DELETE FROM account_t WHERE NOT uuid = :uuid AND (name = :name OR spn = :spn OR gidnumber = :gidnumber)",
680 named_params!{
681 ":uuid": &account_uuid,
682 ":name": &account.name,
683 ":spn": &account.spn,
684 ":gidnumber": &account.gidnumber,
685 }
686 )
687 .map_err(|e| {
688 self.sqlite_error("delete account_t duplicate", &e)
689 })
690 .map(|_| ())?;
691
692 let updated = self.conn.execute(
693 "UPDATE account_t SET name=:name, spn=:spn, gidnumber=:gidnumber, token=:token, expiry=:expiry WHERE uuid = :uuid",
694 named_params!{
695 ":uuid": &account_uuid,
696 ":name": &account.name,
697 ":spn": &account.spn,
698 ":gidnumber": &account.gidnumber,
699 ":token": &data,
700 ":expiry": &expire,
701 }
702 )
703 .map_err(|e| {
704 self.sqlite_error("delete account_t duplicate", &e)
705 })?;
706
707 if updated == 0 {
708 let mut stmt = self.conn
709 .prepare("INSERT INTO account_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry) ON CONFLICT(uuid) DO UPDATE SET name=excluded.name, spn=excluded.name, gidnumber=excluded.gidnumber, token=excluded.token, expiry=excluded.expiry")
710 .map_err(|e| {
711 self.sqlite_error("prepare", &e)
712 })?;
713
714 stmt.execute(named_params! {
715 ":uuid": &account_uuid,
716 ":name": &account.name,
717 ":spn": &account.spn,
718 ":gidnumber": &account.gidnumber,
719 ":token": &data,
720 ":expiry": &expire,
721 })
722 .map(|r| {
723 trace!("insert -> {:?}", r);
724 })
725 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
726 }
727
728 let mut stmt = self
732 .conn
733 .prepare("DELETE FROM memberof_t WHERE a_uuid = :a_uuid")
734 .map_err(|e| self.sqlite_error("prepare", &e))?;
735
736 stmt.execute([&account_uuid])
737 .map(|r| {
738 trace!("delete memberships -> {:?}", r);
739 })
740 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))?;
741
742 let mut stmt = self
743 .conn
744 .prepare("INSERT INTO memberof_t (a_uuid, g_uuid) VALUES (:a_uuid, :g_uuid)")
745 .map_err(|e| self.sqlite_error("prepare", &e))?;
746 account.groups.iter().try_for_each(|g| {
748 stmt.execute(named_params! {
749 ":a_uuid": &account_uuid,
750 ":g_uuid": &g.uuid.as_hyphenated().to_string(),
751 })
752 .map(|r| {
753 trace!("insert membership -> {:?}", r);
754 })
755 .map_err(|error| self.sqlite_transaction_error(&error, &stmt))
756 })
757 }
758
759 pub fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
760 let account_uuid = a_uuid.as_hyphenated().to_string();
761
762 self.conn
763 .execute(
764 "DELETE FROM memberof_t WHERE a_uuid = :a_uuid",
765 params![&account_uuid],
766 )
767 .map(|_| ())
768 .map_err(|e| self.sqlite_error("account_t memberof_t cascade delete", &e))?;
769
770 self.conn
771 .execute(
772 "DELETE FROM account_t WHERE uuid = :a_uuid",
773 params![&account_uuid],
774 )
775 .map(|_| ())
776 .map_err(|e| self.sqlite_error("account_t delete", &e))
777 }
778
779 pub fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
780 let data = match grp_id {
781 Id::Name(n) => self.get_group_data_name(n.as_str()),
782 Id::Gid(g) => self.get_group_data_gid(*g),
783 }?;
784
785 if data.len() >= 2 {
787 error!("invalid db state, multiple entries matched query?");
788 return Err(CacheError::TooManyResults);
789 }
790
791 if let Some((token, expiry)) = data.first() {
792 match serde_json::from_slice(token.as_slice()) {
796 Ok(t) => {
797 let e = u64::try_from(*expiry).map_err(|e| {
798 error!("u64 convert error -> {:?}", e);
799 CacheError::Parse
800 })?;
801 Ok(Some((t, e)))
802 }
803 Err(e) => {
804 warn!("recoverable - json error -> {:?}", e);
805 Ok(None)
806 }
807 }
808 } else {
809 Ok(None)
810 }
811 }
812
813 pub fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
814 let mut stmt = self
815 .conn
816 .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
817 .map_err(|e| {
818 self.sqlite_error("select prepare", &e)
819 })?;
820
821 let data_iter = stmt
822 .query_map([g_uuid.as_hyphenated().to_string()], |row| row.get(0))
823 .map_err(|e| self.sqlite_error("query_map", &e))?;
824 let data: Result<Vec<Vec<u8>>, _> = data_iter
825 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
826 .collect();
827
828 let data = data?;
829
830 data.iter()
831 .map(|token| {
832 serde_json::from_slice(token.as_slice()).map_err(|e| {
835 error!("json error -> {:?}", e);
836 CacheError::SerdeJson
837 })
838 })
839 .collect()
840 }
841
842 pub fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
843 let mut stmt = self
844 .conn
845 .prepare("SELECT token FROM group_t")
846 .map_err(|e| self.sqlite_error("select prepare", &e))?;
847
848 let data_iter = stmt
849 .query_map([], |row| row.get(0))
850 .map_err(|e| self.sqlite_error("query_map", &e))?;
851 let data: Result<Vec<Vec<u8>>, _> = data_iter
852 .map(|v| v.map_err(|e| self.sqlite_error("map", &e)))
853 .collect();
854
855 let data = data?;
856
857 Ok(data
858 .iter()
859 .filter_map(|token| {
860 serde_json::from_slice(token.as_slice())
863 .map_err(|e| {
864 error!("json error -> {:?}", e);
865 })
866 .ok()
867 })
868 .collect())
869 }
870
871 pub fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
872 let data = serde_json::to_vec(grp).map_err(|e| {
873 error!("json error -> {:?}", e);
874 CacheError::SerdeJson
875 })?;
876 let expire = i64::try_from(expire).map_err(|e| {
877 error!("i64 convert error -> {:?}", e);
878 CacheError::Parse
879 })?;
880
881 let mut stmt = self.conn
882 .prepare("INSERT OR REPLACE INTO group_t (uuid, name, spn, gidnumber, token, expiry) VALUES (:uuid, :name, :spn, :gidnumber, :token, :expiry)")
883 .map_err(|e| {
884 self.sqlite_error("prepare", &e)
885 })?;
886
887 stmt.execute(named_params! {
889 ":uuid": &grp.uuid.as_hyphenated().to_string(),
890 ":name": &grp.name,
891 ":spn": &grp.spn,
892 ":gidnumber": &grp.gidnumber,
893 ":token": &data,
894 ":expiry": &expire,
895 })
896 .map(|r| {
897 trace!("insert -> {:?}", r);
898 })
899 .map_err(|e| self.sqlite_error("execute", &e))
900 }
901
902 pub fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
903 let group_uuid = g_uuid.as_hyphenated().to_string();
904 self.conn
905 .execute(
906 "DELETE FROM memberof_t WHERE g_uuid = :g_uuid",
907 [&group_uuid],
908 )
909 .map(|_| ())
910 .map_err(|e| self.sqlite_error("group_t memberof_t cascade delete", &e))?;
911 self.conn
912 .execute("DELETE FROM group_t WHERE uuid = :g_uuid", [&group_uuid])
913 .map(|_| ())
914 .map_err(|e| self.sqlite_error("group_t delete", &e))
915 }
916}
917
918impl fmt::Debug for DbTxn<'_> {
919 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
920 write!(f, "DbTxn {{}}")
921 }
922}
923
924impl Drop for DbTxn<'_> {
925 fn drop(&mut self) {
927 if !self.committed {
928 #[allow(clippy::expect_used)]
930 self.conn
931 .execute("ROLLBACK TRANSACTION", [])
932 .expect("Unable to rollback transaction! Can not proceed!!!");
933 }
934 }
935}
936
937#[cfg(test)]
938mod tests {
939 use super::{Cache, Db};
940 use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};
941
942 #[tokio::test]
943 async fn test_cache_db_account_basic() {
944 sketching::test_init();
945 let db = Db::new("").expect("failed to create.");
946 let mut dbtxn = db.write().await;
947 assert!(dbtxn.migrate().is_ok());
948
949 let mut ut1 = UserToken {
950 provider: ProviderOrigin::System,
951 name: "testuser".to_string(),
952 spn: "testuser@example.com".to_string(),
953 displayname: "Test User".to_string(),
954 gidnumber: 2000,
955 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
956 shell: None,
957 groups: Vec::new(),
958 sshkeys: vec!["key-a".to_string()],
959 valid: true,
960 extra_keys: Default::default(),
961 };
962
963 let id_name = Id::Name("testuser".to_string());
964 let id_name2 = Id::Name("testuser2".to_string());
965 let id_spn = Id::Name("testuser@example.com".to_string());
966 let id_spn2 = Id::Name("testuser2@example.com".to_string());
967 let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
968 let id_gid = Id::Gid(2000);
969
970 let r1 = dbtxn.get_account(&id_name).unwrap();
972 assert!(r1.is_none());
973 let r2 = dbtxn.get_account(&id_spn).unwrap();
974 assert!(r2.is_none());
975 let r3 = dbtxn.get_account(&id_uuid).unwrap();
976 assert!(r3.is_none());
977 let r4 = dbtxn.get_account(&id_gid).unwrap();
978 assert!(r4.is_none());
979
980 dbtxn.update_account(&ut1, 0).unwrap();
982
983 let r1 = dbtxn.get_account(&id_name).unwrap();
985 assert!(r1.is_some());
986 let r2 = dbtxn.get_account(&id_spn).unwrap();
987 assert!(r2.is_some());
988 let r3 = dbtxn.get_account(&id_uuid).unwrap();
989 assert!(r3.is_some());
990 let r4 = dbtxn.get_account(&id_gid).unwrap();
991 assert!(r4.is_some());
992
993 ut1.name = "testuser2".to_string();
995 ut1.spn = "testuser2@example.com".to_string();
996 dbtxn.update_account(&ut1, 0).unwrap();
997
998 let r1 = dbtxn.get_account(&id_name).unwrap();
1000 assert!(r1.is_none());
1001 let r2 = dbtxn.get_account(&id_spn).unwrap();
1002 assert!(r2.is_none());
1003 let r1 = dbtxn.get_account(&id_name2).unwrap();
1004 assert!(r1.is_some());
1005 let r2 = dbtxn.get_account(&id_spn2).unwrap();
1006 assert!(r2.is_some());
1007 let r3 = dbtxn.get_account(&id_uuid).unwrap();
1008 assert!(r3.is_some());
1009 let r4 = dbtxn.get_account(&id_gid).unwrap();
1010 assert!(r4.is_some());
1011
1012 assert!(dbtxn.clear().is_ok());
1014
1015 let r1 = dbtxn.get_account(&id_name2).unwrap();
1017 assert!(r1.is_none());
1018 let r2 = dbtxn.get_account(&id_spn2).unwrap();
1019 assert!(r2.is_none());
1020 let r3 = dbtxn.get_account(&id_uuid).unwrap();
1021 assert!(r3.is_none());
1022 let r4 = dbtxn.get_account(&id_gid).unwrap();
1023 assert!(r4.is_none());
1024
1025 assert!(dbtxn.commit().is_ok());
1026 }
1027
1028 #[tokio::test]
1029 async fn test_cache_db_group_basic() {
1030 sketching::test_init();
1031 let db = Db::new("").expect("failed to create.");
1032 let mut dbtxn = db.write().await;
1033 assert!(dbtxn.migrate().is_ok());
1034
1035 let mut gt1 = GroupToken {
1036 provider: ProviderOrigin::System,
1037 name: "testgroup".to_string(),
1038 spn: "testgroup@example.com".to_string(),
1039 gidnumber: 2000,
1040 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
1041 extra_keys: Default::default(),
1042 };
1043
1044 let id_name = Id::Name("testgroup".to_string());
1045 let id_name2 = Id::Name("testgroup2".to_string());
1046 let id_spn = Id::Name("testgroup@example.com".to_string());
1047 let id_spn2 = Id::Name("testgroup2@example.com".to_string());
1048 let id_uuid = Id::Name("0302b99c-f0f6-41ab-9492-852692b0fd16".to_string());
1049 let id_gid = Id::Gid(2000);
1050
1051 let r1 = dbtxn.get_group(&id_name).unwrap();
1053 assert!(r1.is_none());
1054 let r2 = dbtxn.get_group(&id_spn).unwrap();
1055 assert!(r2.is_none());
1056 let r3 = dbtxn.get_group(&id_uuid).unwrap();
1057 assert!(r3.is_none());
1058 let r4 = dbtxn.get_group(&id_gid).unwrap();
1059 assert!(r4.is_none());
1060
1061 dbtxn.update_group(>1, 0).unwrap();
1063 let r1 = dbtxn.get_group(&id_name).unwrap();
1064 assert!(r1.is_some());
1065 let r2 = dbtxn.get_group(&id_spn).unwrap();
1066 assert!(r2.is_some());
1067 let r3 = dbtxn.get_group(&id_uuid).unwrap();
1068 assert!(r3.is_some());
1069 let r4 = dbtxn.get_group(&id_gid).unwrap();
1070 assert!(r4.is_some());
1071
1072 gt1.name = "testgroup2".to_string();
1074 gt1.spn = "testgroup2@example.com".to_string();
1075 dbtxn.update_group(>1, 0).unwrap();
1076 let r1 = dbtxn.get_group(&id_name).unwrap();
1077 assert!(r1.is_none());
1078 let r2 = dbtxn.get_group(&id_spn).unwrap();
1079 assert!(r2.is_none());
1080 let r1 = dbtxn.get_group(&id_name2).unwrap();
1081 assert!(r1.is_some());
1082 let r2 = dbtxn.get_group(&id_spn2).unwrap();
1083 assert!(r2.is_some());
1084 let r3 = dbtxn.get_group(&id_uuid).unwrap();
1085 assert!(r3.is_some());
1086 let r4 = dbtxn.get_group(&id_gid).unwrap();
1087 assert!(r4.is_some());
1088
1089 assert!(dbtxn.clear().is_ok());
1091
1092 let r1 = dbtxn.get_group(&id_name2).unwrap();
1094 assert!(r1.is_none());
1095 let r2 = dbtxn.get_group(&id_spn2).unwrap();
1096 assert!(r2.is_none());
1097 let r3 = dbtxn.get_group(&id_uuid).unwrap();
1098 assert!(r3.is_none());
1099 let r4 = dbtxn.get_group(&id_gid).unwrap();
1100 assert!(r4.is_none());
1101
1102 assert!(dbtxn.commit().is_ok());
1103 }
1104
1105 #[tokio::test]
1106 async fn test_cache_db_account_group_update() {
1107 sketching::test_init();
1108 let db = Db::new("").expect("failed to create.");
1109 let mut dbtxn = db.write().await;
1110 assert!(dbtxn.migrate().is_ok());
1111
1112 let gt1 = GroupToken {
1113 provider: ProviderOrigin::System,
1114 name: "testuser".to_string(),
1115 spn: "testuser@example.com".to_string(),
1116 gidnumber: 2000,
1117 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
1118 extra_keys: Default::default(),
1119 };
1120
1121 let gt2 = GroupToken {
1122 provider: ProviderOrigin::System,
1123 name: "testgroup".to_string(),
1124 spn: "testgroup@example.com".to_string(),
1125 gidnumber: 2001,
1126 uuid: uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"),
1127 extra_keys: Default::default(),
1128 };
1129
1130 let mut ut1 = UserToken {
1131 provider: ProviderOrigin::System,
1132 name: "testuser".to_string(),
1133 spn: "testuser@example.com".to_string(),
1134 displayname: "Test User".to_string(),
1135 gidnumber: 2000,
1136 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
1137 shell: None,
1138 groups: vec![gt1.clone(), gt2],
1139 sshkeys: vec!["key-a".to_string()],
1140 valid: true,
1141 extra_keys: Default::default(),
1142 };
1143
1144 ut1.groups.iter().for_each(|g| {
1146 dbtxn.update_group(g, 0).unwrap();
1147 });
1148
1149 dbtxn.update_account(&ut1, 0).unwrap();
1151
1152 let m1 = dbtxn
1154 .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
1155 .unwrap();
1156 let m2 = dbtxn
1157 .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
1158 .unwrap();
1159 assert_eq!(m1[0].name, "testuser");
1160 assert_eq!(m2[0].name, "testuser");
1161
1162 ut1.groups = vec![gt1];
1164 dbtxn.update_account(&ut1, 0).unwrap();
1165
1166 let m1 = dbtxn
1168 .get_group_members(uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"))
1169 .unwrap();
1170 let m2 = dbtxn
1171 .get_group_members(uuid::uuid!("b500be97-8552-42a5-aca0-668bc5625705"))
1172 .unwrap();
1173 assert_eq!(m1[0].name, "testuser");
1174 assert!(m2.is_empty());
1175
1176 assert!(dbtxn.commit().is_ok());
1177 }
1178
1179 #[tokio::test]
1180 async fn test_cache_db_group_rename_duplicate() {
1181 sketching::test_init();
1182 let db = Db::new("").expect("failed to create.");
1183 let mut dbtxn = db.write().await;
1184 assert!(dbtxn.migrate().is_ok());
1185
1186 let mut gt1 = GroupToken {
1187 provider: ProviderOrigin::System,
1188 name: "testgroup".to_string(),
1189 spn: "testgroup@example.com".to_string(),
1190 gidnumber: 2000,
1191 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
1192 extra_keys: Default::default(),
1193 };
1194
1195 let gt2 = GroupToken {
1196 provider: ProviderOrigin::System,
1197 name: "testgroup".to_string(),
1198 spn: "testgroup@example.com".to_string(),
1199 gidnumber: 2001,
1200 uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
1201 extra_keys: Default::default(),
1202 };
1203
1204 let id_name = Id::Name("testgroup".to_string());
1205 let id_name2 = Id::Name("testgroup2".to_string());
1206
1207 let r1 = dbtxn.get_group(&id_name).unwrap();
1209 assert!(r1.is_none());
1210
1211 dbtxn.update_group(>1, 0).unwrap();
1213 let r0 = dbtxn.get_group(&id_name).unwrap();
1214 assert_eq!(
1215 r0.unwrap().0.uuid,
1216 uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
1217 );
1218
1219 gt1.name = "testgroup2".to_string();
1221 gt1.spn = "testgroup2@example.com".to_string();
1222 dbtxn.update_group(>2, 0).unwrap();
1224 let r2 = dbtxn.get_group(&id_name).unwrap();
1225 assert_eq!(
1226 r2.unwrap().0.uuid,
1227 uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
1228 );
1229 let r3 = dbtxn.get_group(&id_name2).unwrap();
1230 assert!(r3.is_none());
1231
1232 dbtxn.update_group(>1, 0).unwrap();
1234
1235 let r4 = dbtxn.get_group(&id_name).unwrap();
1237 assert_eq!(
1238 r4.unwrap().0.uuid,
1239 uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
1240 );
1241 let r5 = dbtxn.get_group(&id_name2).unwrap();
1242 assert_eq!(
1243 r5.unwrap().0.uuid,
1244 uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
1245 );
1246
1247 assert!(dbtxn.commit().is_ok());
1248 }
1249
1250 #[tokio::test]
1251 async fn test_cache_db_account_rename_duplicate() {
1252 sketching::test_init();
1253 let db = Db::new("").expect("failed to create.");
1254 let mut dbtxn = db.write().await;
1255 assert!(dbtxn.migrate().is_ok());
1256
1257 let mut ut1 = UserToken {
1258 provider: ProviderOrigin::System,
1259 name: "testuser".to_string(),
1260 spn: "testuser@example.com".to_string(),
1261 displayname: "Test User".to_string(),
1262 gidnumber: 2000,
1263 uuid: uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"),
1264 shell: None,
1265 groups: Vec::new(),
1266 sshkeys: vec!["key-a".to_string()],
1267 valid: true,
1268 extra_keys: Default::default(),
1269 };
1270
1271 let ut2 = UserToken {
1272 provider: ProviderOrigin::System,
1273 name: "testuser".to_string(),
1274 spn: "testuser@example.com".to_string(),
1275 displayname: "Test User".to_string(),
1276 gidnumber: 2001,
1277 uuid: uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b"),
1278 shell: None,
1279 groups: Vec::new(),
1280 sshkeys: vec!["key-a".to_string()],
1281 valid: true,
1282 extra_keys: Default::default(),
1283 };
1284
1285 let id_name = Id::Name("testuser".to_string());
1286 let id_name2 = Id::Name("testuser2".to_string());
1287
1288 let r1 = dbtxn.get_account(&id_name).unwrap();
1290 assert!(r1.is_none());
1291
1292 dbtxn.update_account(&ut1, 0).unwrap();
1294 let r0 = dbtxn.get_account(&id_name).unwrap();
1295 assert_eq!(
1296 r0.unwrap().0.uuid,
1297 uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
1298 );
1299
1300 ut1.name = "testuser2".to_string();
1302 ut1.spn = "testuser2@example.com".to_string();
1303 dbtxn.update_account(&ut2, 0).unwrap();
1305 let r2 = dbtxn.get_account(&id_name).unwrap();
1306 assert_eq!(
1307 r2.unwrap().0.uuid,
1308 uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
1309 );
1310 let r3 = dbtxn.get_account(&id_name2).unwrap();
1311 assert!(r3.is_none());
1312
1313 dbtxn.update_account(&ut1, 0).unwrap();
1315
1316 let r4 = dbtxn.get_account(&id_name).unwrap();
1318 assert_eq!(
1319 r4.unwrap().0.uuid,
1320 uuid::uuid!("799123b2-3802-4b19-b0b8-1ffae2aa9a4b")
1321 );
1322 let r5 = dbtxn.get_account(&id_name2).unwrap();
1323 assert_eq!(
1324 r5.unwrap().0.uuid,
1325 uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16")
1326 );
1327
1328 assert!(dbtxn.commit().is_ok());
1329 }
1330}