kanidmd_lib/plugins/oauth2.rs

1use crate::event::{CreateEvent, ModifyEvent};
2use crate::plugins::Plugin;
3use crate::prelude::*;
4use crate::utils::password_from_random;
5use crate::valueset::ValueSetUuid;
6use compact_jwt::{crypto::JwsRs256Signer, JwsEs256Signer};
7use std::sync::Arc;
8
9pub struct OAuth2 {}
10
11impl Plugin for OAuth2 {
12    fn id() -> &'static str {
13        "plugin_oauth2"
14    }
15
16    #[instrument(level = "debug", name = "oauth2_pre_create_transform", skip_all)]
17    fn pre_create_transform(
18        qs: &mut QueryServerWriteTransaction,
19        cand: &mut Vec<Entry<EntryInvalid, EntryNew>>,
20        _ce: &CreateEvent,
21    ) -> Result<(), OperationError> {
22        Self::modify_inner(qs, cand)
23    }
24
25    #[instrument(level = "debug", name = "oauth2_pre_modify", skip_all)]
26    fn pre_modify(
27        qs: &mut QueryServerWriteTransaction,
28        _pre_cand: &[Arc<EntrySealedCommitted>],
29        cand: &mut Vec<Entry<EntryInvalid, EntryCommitted>>,
30        _me: &ModifyEvent,
31    ) -> Result<(), OperationError> {
32        Self::modify_inner(qs, cand)
33    }
34
35    #[instrument(level = "debug", name = "oauth2_pre_batch_modify", skip_all)]
36    fn pre_batch_modify(
37        qs: &mut QueryServerWriteTransaction,
38        _pre_cand: &[Arc<EntrySealedCommitted>],
39        cand: &mut Vec<Entry<EntryInvalid, EntryCommitted>>,
40        _me: &BatchModifyEvent,
41    ) -> Result<(), OperationError> {
42        Self::modify_inner(qs, cand)
43    }
44}
45
46impl OAuth2 {
47    fn modify_inner<T: Clone>(
48        qs: &mut QueryServerWriteTransaction,
49        cand: &mut [Entry<EntryInvalid, T>],
50    ) -> Result<(), OperationError> {
51        let domain_level = qs.get_domain_version();
52
53        cand.iter_mut()
54            .filter(|entry| {
55                entry.attribute_equality(Attribute::Class, &EntryClass::OAuth2Account.into())
56            })
57            .for_each(|entry| {
58                if entry
59                    .get_ava_set(Attribute::OAuth2AccountCredentialUuid)
60                    .is_none()
61                {
62                    entry.set_ava_set(
63                        &Attribute::OAuth2AccountCredentialUuid,
64                        ValueSetUuid::new(Uuid::new_v4()),
65                    )
66                }
67            });
68
69        // Populate attributes into the oauth2 clients.
70        cand.iter_mut()
71            .filter(|entry| {
72                entry.attribute_equality(Attribute::Class, &EntryClass::OAuth2ResourceServer.into())
73            })
74            .try_for_each(|entry| {
75                // Regenerate the basic secret, if needed
76                if entry.attribute_equality(Attribute::Class, &EntryClass::OAuth2ResourceServerBasic.into()) &&
77                    !entry.attribute_pres(Attribute::OAuth2RsBasicSecret) {
78                        security_info!("regenerating oauth2 basic secret");
79                        let v = Value::SecretValue(password_from_random());
80                        entry.add_ava(Attribute::OAuth2RsBasicSecret, v);
81                }
82
83            let has_rs256 = entry.get_ava_single_bool(Attribute::OAuth2JwtLegacyCryptoEnable).unwrap_or(false);
84
85            if domain_level >= DOMAIN_LEVEL_10 {
86                debug!("Generating OAuth2 Key Object");
87                // OAuth2 now requires a KeyObject, configure it now.
88                entry.add_ava(Attribute::Class, EntryClass::KeyObject.to_value());
89                entry.add_ava(Attribute::Class, EntryClass::KeyObjectJwtEs256.to_value());
90                entry.add_ava(Attribute::Class, EntryClass::KeyObjectJweA128GCM.to_value());
91                if has_rs256 {
92                    entry.add_ava(Attribute::Class, EntryClass::KeyObjectJwtRs256.to_value());
93                }
94            } else {
95                if !entry.attribute_pres(Attribute::OAuth2RsTokenKey) {
96                    security_info!("regenerating oauth2 token key");
97                    let k = password_from_random();
98                    let v = Value::new_secret_str(&k);
99                    entry.add_ava(Attribute::OAuth2RsTokenKey, v);
100                }
101                if !entry.attribute_pres(Attribute::Es256PrivateKeyDer) {
102                    security_info!("regenerating oauth2 es256 private key");
103                    let der = JwsEs256Signer::generate_es256()
104                        .and_then(|jws| jws.private_key_to_der())
105                        .map_err(|e| {
106                            admin_error!(err = ?e, "Unable to generate ES256 JwsSigner private key");
107                            OperationError::CryptographyError
108                        })?;
109                    let v = Value::new_privatebinary(&der);
110                    entry.add_ava(Attribute::Es256PrivateKeyDer, v);
111                }
112                    if has_rs256 && !entry.attribute_pres(Attribute::Rs256PrivateKeyDer) {
113                    security_info!("regenerating oauth2 legacy rs256 private key");
114                    let der = JwsRs256Signer::generate_rs256()
115                        .and_then(|jws| jws.private_key_to_der())
116                        .map_err(|e| {
117                            admin_error!(err = ?e, "Unable to generate Legacy RS256 JwsSigner private key");
118                            OperationError::CryptographyError
119                        })?;
120                    let v = Value::new_privatebinary(&der);
121                    entry.add_ava(Attribute::Rs256PrivateKeyDer, v);
122                }
123            }
124
125            Ok(())
126        })
127    }
128}
129
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    /// Creating a basic resource server without a client secret must leave
    /// the stored entry with a generated `OAuth2RsBasicSecret`.
    #[test]
    fn test_pre_create_oauth2_secrets() {
        let preload: Vec<Entry<EntryInit, EntryNew>> = Vec::new();

        let rs_uuid = Uuid::new_v4();
        let rs_entry: Entry<EntryInit, EntryNew> = entry_init!(
            (Attribute::Class, EntryClass::Object.to_value()),
            (Attribute::Class, EntryClass::Account.to_value()),
            (
                Attribute::Class,
                EntryClass::OAuth2ResourceServer.to_value()
            ),
            (
                Attribute::Class,
                EntryClass::OAuth2ResourceServerBasic.to_value()
            ),
            (Attribute::Uuid, Value::Uuid(rs_uuid)),
            (
                Attribute::DisplayName,
                Value::new_utf8s("test_resource_server")
            ),
            (Attribute::Name, Value::new_iname("test_resource_server")),
            (
                Attribute::OAuth2RsOriginLanding,
                Value::new_url_s("https://demo.example.com").unwrap()
            ),
            (
                Attribute::OAuth2RsScopeMap,
                Value::new_oauthscopemap(
                    UUID_IDM_ALL_ACCOUNTS,
                    btreeset![OAUTH2_SCOPE_READ.to_string()]
                )
                .expect("invalid oauthscope")
            )
        );

        let create = vec![rs_entry];

        run_create_test!(
            Ok(None),
            preload,
            create,
            None,
            |qs: &mut QueryServerWriteTransaction| {
                let stored = qs
                    .internal_search_uuid(rs_uuid)
                    .expect("failed to get oauth2 config");
                assert!(stored.attribute_pres(Attribute::OAuth2RsBasicSecret));
            }
        );
    }

    /// Purging the basic secret of an existing resource server must cause a
    /// new, different secret to be generated during the modify.
    #[test]
    fn test_modify_oauth2_secrets_regenerate() {
        let rs_uuid = Uuid::new_v4();

        let rs_entry: Entry<EntryInit, EntryNew> = entry_init!(
            (Attribute::Class, EntryClass::Object.to_value()),
            (Attribute::Class, EntryClass::Account.to_value()),
            (
                Attribute::Class,
                EntryClass::OAuth2ResourceServer.to_value()
            ),
            (
                Attribute::Class,
                EntryClass::OAuth2ResourceServerBasic.to_value()
            ),
            (Attribute::Uuid, Value::Uuid(rs_uuid)),
            (Attribute::Name, Value::new_iname("test_resource_server")),
            (
                Attribute::DisplayName,
                Value::new_utf8s("test_resource_server")
            ),
            (
                Attribute::OAuth2RsOriginLanding,
                Value::new_url_s("https://demo.example.com").unwrap()
            ),
            (
                Attribute::OAuth2RsScopeMap,
                Value::new_oauthscopemap(
                    UUID_IDM_ALL_ACCOUNTS,
                    btreeset![OAUTH2_SCOPE_READ.to_string()]
                )
                .expect("invalid oauthscope")
            ),
            (
                Attribute::OAuth2RsBasicSecret,
                Value::new_secret_str("12345")
            )
        );

        let preload = vec![rs_entry];

        run_modify_test!(
            Ok(()),
            preload,
            filter!(f_eq(Attribute::Uuid, PartialValue::Uuid(rs_uuid))),
            ModifyList::new_list(vec![
                Modify::Purged(Attribute::OAuth2RsBasicSecret),
                Modify::Purged(Attribute::OAuth2RsTokenKey)
            ]),
            None,
            |_| {},
            |qs: &mut QueryServerWriteTransaction| {
                let stored = qs
                    .internal_search_uuid(rs_uuid)
                    .expect("failed to get oauth2 config");
                assert!(stored.attribute_pres(Attribute::OAuth2RsBasicSecret));
                // The purged secret must have been replaced, not restored.
                assert_ne!(
                    stored.get_ava_single_secret(Attribute::OAuth2RsBasicSecret),
                    Some("12345")
                );
            }
        );
    }
}