kanidmd_core/repl/
mod.rs

1use self::codec::{ConsumerRequest, SupplierResponse};
2use crate::CoreAction;
3use config::{RepNodeConfig, ReplicationConfiguration};
4use futures_util::sink::SinkExt;
5use futures_util::stream::StreamExt;
6use kanidmd_lib::prelude::duration_from_epoch_now;
7use kanidmd_lib::prelude::IdmServer;
8use kanidmd_lib::repl::proto::ConsumerState;
9use kanidmd_lib::server::QueryServerTransaction;
10use openssl::x509::X509;
11use rustls::{
12    client::ClientConfig,
13    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
14    server::{ServerConfig, WebPkiClientVerifier},
15    RootCertStore,
16};
17use std::collections::VecDeque;
18use std::net::SocketAddr;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::broadcast;
22use tokio::sync::mpsc;
23use tokio::sync::oneshot;
24use tokio::sync::Mutex;
25use tokio::time::{interval, sleep, timeout};
26use tokio::{
27    net::{TcpListener, TcpStream},
28    task::JoinHandle,
29};
30use tokio_rustls::{client::TlsStream, TlsAcceptor, TlsConnector};
31use tokio_util::codec::{Framed, FramedRead, FramedWrite};
32use tracing::{error, Instrument};
33use url::Url;
34use uuid::Uuid;
35
36mod codec;
37pub(crate) mod config;
38
39pub(crate) enum ReplCtrl {
40    GetCertificate {
41        respond: oneshot::Sender<X509>,
42    },
43    RenewCertificate {
44        respond: oneshot::Sender<bool>,
45    },
46    RefreshConsumer {
47        respond: oneshot::Sender<mpsc::Receiver<()>>,
48    },
49}
50
51#[derive(Debug, Clone)]
52enum ReplConsumerCtrl {
53    Stop,
54    Refresh(Arc<Mutex<(bool, mpsc::Sender<()>)>>),
55}
56
57pub(crate) async fn create_repl_server(
58    idms: Arc<IdmServer>,
59    repl_config: &ReplicationConfiguration,
60    rx: broadcast::Receiver<CoreAction>,
61) -> Result<(tokio::task::JoinHandle<()>, mpsc::Sender<ReplCtrl>), ()> {
62    // We need to start the tcp listener. This will persist over ssl reloads!
63    let listener = TcpListener::bind(&repl_config.bindaddress)
64        .await
65        .map_err(|e| {
66            error!(
67                "Could not bind to replication address {} -> {:?}",
68                repl_config.bindaddress, e
69            );
70        })?;
71
72    // Create the control channel. Use a low msg count, there won't be that much going on.
73    let (ctrl_tx, ctrl_rx) = mpsc::channel(4);
74
75    // We need to start the tcp listener. This will persist over ssl reloads!
76    info!(
77        "Starting replication interface https://{} ...",
78        repl_config.bindaddress
79    );
80    let repl_handle: JoinHandle<()> = tokio::spawn(repl_acceptor(
81        listener,
82        idms,
83        repl_config.clone(),
84        rx,
85        ctrl_rx,
86    ));
87
88    info!("Created replication interface");
89    Ok((repl_handle, ctrl_tx))
90}
91
92#[instrument(level = "debug", skip_all)]
93/// This returns the remote address that worked, so you can try that first next time
94async fn repl_consumer_connect_supplier(
95    server_name: &ServerName<'static>,
96    sock_addrs: &[SocketAddr],
97    tls_connector: &TlsConnector,
98    consumer_conn_settings: &ConsumerConnSettings,
99) -> Option<(
100    SocketAddr,
101    Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
102)> {
103    // This is pretty gnarly, but we need to loop to try out each socket addr.
104    for sock_addr in sock_addrs {
105        debug!(
106            "Attempting to connect to {} replica via {}",
107            server_name.to_str(),
108            sock_addr
109        );
110
111        let tcpstream = match timeout(
112            consumer_conn_settings.replica_connect_timeout,
113            TcpStream::connect(sock_addr),
114        )
115        .await
116        {
117            Ok(Ok(tc)) => {
118                trace!("Connection established to peer on {:?}", sock_addr);
119                tc
120            }
121            Ok(Err(err)) => {
122                debug!(?err, "Failed to connect to {}", sock_addr);
123                continue;
124            }
125            Err(_) => {
126                debug!("Timeout connecting to {}", sock_addr);
127                continue;
128            }
129        };
130
131        let tlsstream = match tls_connector
132            .connect(server_name.to_owned(), tcpstream)
133            .await
134        {
135            Ok(ta) => ta,
136            Err(e) => {
137                error!("Replication client TLS setup error, continuing -> {:?}", e);
138                continue;
139            }
140        };
141
142        let supplier_conn = Framed::new(
143            tlsstream,
144            codec::ConsumerCodec::new(consumer_conn_settings.max_frame_bytes),
145        );
146        // "hey this one worked, try it first next time!"
147        return Some((sock_addr.to_owned(), supplier_conn));
148    }
149
150    error!(
151        "Unable to connect to supplier, tried to connect to {:?}",
152        sock_addrs
153    );
154    None
155}
156
157/// This returns the socket address that worked, so you can try that first next time
158#[instrument(level="info", skip(refresh_coord, tls_connector, idms), fields(uuid=Uuid::new_v4().to_string()))]
159async fn repl_run_consumer_refresh(
160    refresh_coord: Arc<Mutex<(bool, mpsc::Sender<()>)>>,
161    server_name: &ServerName<'static>,
162    sock_addrs: &[SocketAddr],
163    tls_connector: &TlsConnector,
164    idms: &IdmServer,
165    consumer_conn_settings: &ConsumerConnSettings,
166) -> Result<Option<SocketAddr>, ()> {
167    // Take the refresh lock. Note that every replication consumer *should* end up here
168    // behind this lock, but only one can proceed. This is what we want!
169
170    let mut refresh_coord_guard = refresh_coord.lock().await;
171
172    // Simple case - task is already done.
173    if refresh_coord_guard.0 {
174        trace!("Refresh already completed by another task, return.");
175        return Ok(None);
176    }
177
178    // okay, we need to proceed.
179    let (addr, mut supplier_conn) = repl_consumer_connect_supplier(
180        server_name,
181        sock_addrs,
182        tls_connector,
183        consumer_conn_settings,
184    )
185    .await
186    .ok_or(())?;
187
188    // If we fail at any point, just RETURN because this leaves the next task to attempt, or
189    // the channel drops and that tells the caller this failed.
190    supplier_conn
191        .send(ConsumerRequest::Refresh)
192        .await
193        .map_err(|err| error!(?err, "consumer encode error, unable to continue."))?;
194
195    let refresh = if let Some(codec_msg) = supplier_conn.next().await {
196        match codec_msg.map_err(|err| error!(?err, "Consumer decode error, unable to continue."))? {
197            SupplierResponse::Refresh(changes) => {
198                // Success - return to bypass the error message.
199                changes
200            }
201            SupplierResponse::Pong | SupplierResponse::Incremental(_) => {
202                error!("Supplier Response contains invalid State");
203                return Err(());
204            }
205        }
206    } else {
207        error!("Connection closed");
208        return Err(());
209    };
210
211    // Now apply the refresh if possible
212    {
213        // Scope the transaction.
214        let ct = duration_from_epoch_now();
215        idms.proxy_write(ct)
216            .await
217            .and_then(|mut write_txn| {
218                write_txn
219                    .qs_write
220                    .consumer_apply_refresh(refresh)
221                    .and_then(|cs| write_txn.commit().map(|()| cs))
222            })
223            .map_err(|err| error!(?err, "Consumer was not able to apply refresh."))?;
224    }
225
226    // Now mark the refresh as complete AND indicate it to the channel.
227    refresh_coord_guard.0 = true;
228    if refresh_coord_guard.1.send(()).await.is_err() {
229        warn!("Unable to signal to caller that refresh has completed.");
230    }
231
232    // Here the coord guard will drop and every other task proceeds.
233
234    info!("Replication refresh was successful.");
235    Ok(Some(addr))
236}
237
238#[instrument(level="debug", skip(tls_connector, idms), fields(eventid=Uuid::new_v4().to_string()))]
239async fn repl_run_consumer(
240    server_name: &ServerName<'static>,
241    sock_addrs: &[SocketAddr],
242    tls_connector: &TlsConnector,
243    automatic_refresh: bool,
244    idms: &IdmServer,
245    consumer_conn_settings: &ConsumerConnSettings,
246) -> Option<SocketAddr> {
247    let (socket_addr, mut supplier_conn) = repl_consumer_connect_supplier(
248        server_name,
249        sock_addrs,
250        tls_connector,
251        consumer_conn_settings,
252    )
253    .await?;
254
255    // Perform incremental.
256    let consumer_ruv_range = {
257        let consumer_state = idms
258            .proxy_read()
259            .await
260            .and_then(|mut read_txn| read_txn.qs_read.consumer_get_state());
261        match consumer_state {
262            Ok(ruv_range) => ruv_range,
263            Err(err) => {
264                error!(
265                    ?err,
266                    "consumer ruv range could not be accessed, unable to continue."
267                );
268                return None;
269            }
270        }
271    };
272
273    if let Err(err) = supplier_conn
274        .send(ConsumerRequest::Incremental(consumer_ruv_range))
275        .await
276    {
277        error!(?err, "consumer encode error, unable to continue.");
278        return None;
279    }
280
281    let changes = if let Some(codec_msg) = supplier_conn.next().await {
282        match codec_msg {
283            Ok(SupplierResponse::Incremental(changes)) => {
284                // Success - return to bypass the error message.
285                changes
286            }
287            Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Refresh(_)) => {
288                error!("Supplier Response contains invalid state");
289                return None;
290            }
291            Err(err) => {
292                error!(?err, "Consumer decode error, unable to continue.");
293                return None;
294            }
295        }
296    } else {
297        error!("Connection closed");
298        return None;
299    };
300
301    // Now apply the changes if possible
302    let consumer_state = {
303        let ct = duration_from_epoch_now();
304        match idms.proxy_write(ct).await.and_then(|mut write_txn| {
305            write_txn
306                .qs_write
307                .consumer_apply_changes(changes)
308                .and_then(|cs| write_txn.commit().map(|()| cs))
309        }) {
310            Ok(state) => state,
311            Err(err) => {
312                error!(?err, "Consumer was not able to apply changes.");
313                return None;
314            }
315        }
316    };
317
318    match consumer_state {
319        ConsumerState::Ok => {
320            info!("Incremental Replication Success");
321            // return to bypass the failure message.
322            return Some(socket_addr);
323        }
324        ConsumerState::RefreshRequired => {
325            if automatic_refresh {
326                warn!("Consumer is out of date and must be refreshed. This will happen *now*.");
327            } else {
328                error!("Consumer is out of date and must be refreshed. You must manually resolve this situation.");
329                return None;
330            };
331        }
332    }
333
334    if let Err(err) = supplier_conn.send(ConsumerRequest::Refresh).await {
335        error!(?err, "consumer encode error, unable to continue.");
336        return None;
337    }
338
339    let refresh = if let Some(codec_msg) = supplier_conn.next().await {
340        match codec_msg {
341            Ok(SupplierResponse::Refresh(changes)) => {
342                // Success - return to bypass the error message.
343                changes
344            }
345            Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Incremental(_)) => {
346                error!("Supplier Response contains invalid State");
347                return None;
348            }
349            Err(err) => {
350                error!(?err, "consumer decode error, unable to continue.");
351                return None;
352            }
353        }
354    } else {
355        error!("Connection closed");
356        return None;
357    };
358
359    // Now apply the refresh if possible
360    let ct = duration_from_epoch_now();
361    if let Err(err) = idms.proxy_write(ct).await.and_then(|mut write_txn| {
362        write_txn
363            .qs_write
364            .consumer_apply_refresh(refresh)
365            .and_then(|cs| write_txn.commit().map(|()| cs))
366    }) {
367        error!(?err, "consumer was not able to apply refresh.");
368        return None;
369    }
370
371    info!("Replication refresh was successful.");
372    Some(socket_addr)
373}
374
/// Connection parameters shared by every replication consumer task.
#[derive(Debug, Clone)]
struct ConsumerConnSettings {
    /// Maximum frame size, in bytes, passed to `codec::ConsumerCodec::new`.
    max_frame_bytes: usize,
    /// Interval between incremental replication attempts by a consumer task.
    task_poll_interval: Duration,
    /// Timeout applied when connecting to a supplier address.
    replica_connect_timeout: Duration,
}
381
382#[allow(clippy::too_many_arguments)]
383async fn repl_task(
384    origin: Url,
385
386    client_key: PrivateKeyDer<'static>,
387    client_cert: CertificateDer<'static>,
388    supplier_cert: CertificateDer<'static>,
389
390    consumer_conn_settings: ConsumerConnSettings,
391    mut task_rx: broadcast::Receiver<ReplConsumerCtrl>,
392    automatic_refresh: bool,
393    idms: Arc<IdmServer>,
394) {
395    if origin.scheme() != "repl" {
396        error!("Replica origin is not repl:// - refusing to proceed.");
397        return;
398    }
399
400    let domain = match origin.domain() {
401        Some(d) => d,
402        None => {
403            error!("Replica origin does not have a valid domain name, unable to proceed. Perhaps you tried to use an ip address?");
404            return;
405        }
406    };
407
408    let Ok(server_name) = ServerName::try_from(domain.to_owned()) else {
409        error!("Replica origin does not have a valid domain name, unable to proceed.");
410        return;
411    };
412
413    // Add the supplier cert.
414    // ⚠️  note that here we need to build a new cert store. This is because
415    // we want to pin a single certificate!
416    let mut root_cert_store = RootCertStore::empty();
417    if let Err(err) = root_cert_store.add(supplier_cert) {
418        error!(?err, "Replica supplier cert invalid.");
419        return;
420    };
421
422    let provider = rustls::crypto::aws_lc_rs::default_provider().into();
423
424    let tls_client_config = match ClientConfig::builder_with_provider(provider)
425        .with_safe_default_protocol_versions()
426        .and_then(|builder| {
427            builder
428                .with_root_certificates(root_cert_store)
429                .with_client_auth_cert(vec![client_cert], client_key)
430        }) {
431        Ok(ccb) => ccb,
432        Err(err) => {
433            error!(?err, "Unable to build TLS client configuration");
434            return;
435        }
436    };
437
438    let tls_connector = TlsConnector::from(Arc::new(tls_client_config));
439
440    let mut repl_interval = interval(consumer_conn_settings.task_poll_interval);
441
442    info!("Replica task for {} has started.", origin);
443
444    // we keep track of the "last known good" socketaddr so we can try that first next time.
445    let mut last_working_address: Option<SocketAddr> = None;
446
447    // Okay, all the parameters are set up. Now we replicate on our interval.
448    loop {
449        // we resolve the DNS entry to the ip:port each time we attempt a connection to avoid stale
450        // DNS issues, ref #3188. If we are unable to resolve the address, we backoff and try again
451        // as in something like docker the address may change frequently.
452        //
453        // Note, if DNS isn't available, we can proceed with the last used working address too. This
454        // prevents DNS (or lack thereof) from causing a replication outage.
455        let mut sorted_socket_addrs = vec![];
456
457        // If the target address worked last time, then let's use it this time!
458        if let Some(addr) = last_working_address {
459            debug!(?last_working_address);
460            sorted_socket_addrs.push(addr);
461        };
462
463        // Default to port 443 if not set in the origin
464        match origin.socket_addrs(|| Some(443)) {
465            Ok(mut socket_addrs) => {
466                // Make every address unique.
467                socket_addrs.sort_unstable();
468                socket_addrs.dedup();
469
470                // The only possible conflict is with the last working address,
471                // so lets just check that.
472                socket_addrs.into_iter().for_each(|addr| {
473                    if Some(&addr) != last_working_address.as_ref() {
474                        // Not already present, append
475                        sorted_socket_addrs.push(addr);
476                    }
477                });
478            }
479            Err(err) => {
480                if let Some(addr) = last_working_address {
481                    warn!(
482                        ?err,
483                        "Unable to resolve '{origin}' to ip:port, using last known working address '{addr}'"
484                    );
485                } else {
486                    warn!(?err, "Unable to resolve '{origin}' to ip:port.");
487                }
488            }
489        };
490
491        if sorted_socket_addrs.is_empty() {
492            warn!(
493                "No replication addresses available, delaying replication operation for '{origin}'"
494            );
495            repl_interval.tick().await;
496            continue;
497        }
498
499        tokio::select! {
500            Ok(task) = task_rx.recv() => {
501                match task {
502                    ReplConsumerCtrl::Stop => break,
503                    ReplConsumerCtrl::Refresh ( refresh_coord ) => {
504                        last_working_address = (repl_run_consumer_refresh(
505                            refresh_coord,
506                            &server_name,
507                            &sorted_socket_addrs,
508                            &tls_connector,
509                            &idms,
510                            &consumer_conn_settings
511                        )
512                        .await).unwrap_or_default();
513                    }
514                }
515            }
516            _ = repl_interval.tick() => {
517                // Interval passed, attempt a replication run.
518                repl_run_consumer(
519                    &server_name,
520                    &sorted_socket_addrs,
521                    &tls_connector,
522                    automatic_refresh,
523                    &idms,
524                    &consumer_conn_settings
525                )
526                .await;
527            }
528        }
529    }
530
531    info!("Replica task for {} has stopped.", origin);
532}
533
534#[instrument(level = "debug", skip_all)]
535async fn handle_repl_conn(
536    max_frame_bytes: usize,
537    tcpstream: TcpStream,
538    client_address: SocketAddr,
539    tls_acceptor: TlsAcceptor,
540    idms: Arc<IdmServer>,
541) {
542    debug!(?client_address, "replication client connected 🛫");
543
544    let tlsstream = match tls_acceptor.accept(tcpstream).await {
545        Ok(ta) => ta,
546        Err(err) => {
547            error!(?err, "Replication TLS setup error, disconnecting client");
548            return;
549        }
550    };
551
552    let (r, w) = tokio::io::split(tlsstream);
553    let mut r = FramedRead::new(r, codec::SupplierCodec::new(max_frame_bytes));
554    let mut w = FramedWrite::new(w, codec::SupplierCodec::new(max_frame_bytes));
555
556    while let Some(codec_msg) = r.next().await {
557        match codec_msg {
558            Ok(ConsumerRequest::Ping) => {
559                debug!("consumer requested ping");
560                if let Err(err) = w.send(SupplierResponse::Pong).await {
561                    error!(?err, "supplier encode error, unable to continue.");
562                    break;
563                }
564            }
565            Ok(ConsumerRequest::Incremental(consumer_ruv_range)) => {
566                let changes = match idms.proxy_read().await.and_then(|mut read_txn| {
567                    read_txn
568                        .qs_read
569                        .supplier_provide_changes(consumer_ruv_range)
570                }) {
571                    Ok(changes) => changes,
572                    Err(err) => {
573                        error!(?err, "supplier provide changes failed.");
574                        break;
575                    }
576                };
577
578                if let Err(err) = w.send(SupplierResponse::Incremental(changes)).await {
579                    error!(?err, "supplier encode error, unable to continue.");
580                    break;
581                }
582            }
583            Ok(ConsumerRequest::Refresh) => {
584                let changes = match idms
585                    .proxy_read()
586                    .await
587                    .and_then(|mut read_txn| read_txn.qs_read.supplier_provide_refresh())
588                {
589                    Ok(changes) => changes,
590                    Err(err) => {
591                        error!(?err, "supplier provide refresh failed.");
592                        break;
593                    }
594                };
595
596                if let Err(err) = w.send(SupplierResponse::Refresh(changes)).await {
597                    error!(?err, "supplier encode error, unable to continue.");
598                    break;
599                }
600            }
601            Err(err) => {
602                error!(?err, "supplier decode error, unable to continue.");
603                break;
604            }
605        }
606    }
607
608    debug!(?client_address, "replication client disconnected 🛬");
609}
610
611/// This is the main acceptor for the replication server.
612async fn repl_acceptor(
613    listener: TcpListener,
614    idms: Arc<IdmServer>,
615    repl_config: ReplicationConfiguration,
616    mut rx: broadcast::Receiver<CoreAction>,
617    mut ctrl_rx: mpsc::Receiver<ReplCtrl>,
618) {
619    info!("Starting Replication Acceptor ...");
620    // Persistent parts
621    // These all probably need changes later ...
622    let replica_connect_timeout = Duration::from_secs(2);
623    let retry_timeout = Duration::from_secs(60);
624    let max_frame_bytes = 268435456;
625
626    let consumer_conn_settings = ConsumerConnSettings {
627        max_frame_bytes,
628        task_poll_interval: repl_config.get_task_poll_interval(),
629        replica_connect_timeout,
630    };
631
632    // Setup a broadcast to control our tasks.
633    let (task_tx, task_rx1) = broadcast::channel(1);
634    // Note, we drop this task here since each task will re-subscribe. That way the
635    // broadcast doesn't jam up because we aren't draining this task.
636    drop(task_rx1);
637    let mut task_handles = VecDeque::new();
638
639    // Create another broadcast to control the replication tasks and their need to reload.
640
641    // Spawn a KRC communication task?
642
643    // In future we need to update this from the KRC if configured, and we default this
644    // to "empty". But if this map exists in the config, we have to always use that.
645    let replication_node_map = repl_config.manual.clone();
646    let domain_name = match repl_config.origin.domain() {
647        Some(n) => n.to_string(),
648        None => {
649            error!("Unable to start replication, replication origin does not contain a valid domain name.");
650            return;
651        }
652    };
653
654    // This needs to have an event loop that can respond to changes.
655    // For now we just design it to reload ssl if the map changes internally.
656    'event: loop {
657        info!("Starting replication reload ...");
658        // Tell existing tasks to shutdown.
659        // Note: We ignore the result here since an err can occur *if* there are
660        // no tasks currently listening on the channel.
661        info!("Stopping {} Replication Tasks ...", task_handles.len());
662        debug_assert!(task_handles.len() >= task_tx.receiver_count());
663        let _ = task_tx.send(ReplConsumerCtrl::Stop);
664        for task_handle in task_handles.drain(..) {
665            // Let each task join.
666            let res: Result<(), _> = task_handle.await;
667            if res.is_err() {
668                warn!("Failed to join replication task, continuing ...");
669            }
670        }
671
672        // Now we can start to re-load configurations and setup our client tasks
673        // as well.
674
675        // Get our private key / cert.
676        let res = {
677            let ct = duration_from_epoch_now();
678            idms.proxy_write(ct).await.and_then(|mut idms_prox_write| {
679                idms_prox_write
680                    .qs_write
681                    .supplier_get_key_cert(&domain_name)
682                    .and_then(|res| idms_prox_write.commit().map(|()| res))
683            })
684        };
685
686        let (server_key, server_cert) = match res {
687            Ok(r) => r,
688            Err(err) => {
689                error!(?err, "CRITICAL: Unable to access supplier certificate/key.");
690                sleep(retry_timeout).await;
691                continue;
692            }
693        };
694
695        info!(
696            replication_cert_not_before = ?server_cert.not_before(),
697            replication_cert_not_after = ?server_cert.not_after(),
698        );
699
700        // rustls expects these to be der
701        let Ok(server_key_der) = server_key.private_key_to_der() else {
702            error!("CRITICAL: Unable to convert server key to DER.");
703            sleep(retry_timeout).await;
704            continue;
705        };
706
707        let Ok(server_key_der) = PrivateKeyDer::try_from(server_key_der) else {
708            error!("CRITICAL: Unable to convert server key from DER.");
709            sleep(retry_timeout).await;
710            continue;
711        };
712
713        let Ok(server_cert_der) = server_cert.to_der().map(CertificateDer::from) else {
714            error!("CRITICAL: Unable to convert server cert to DER.");
715            sleep(retry_timeout).await;
716            continue;
717        };
718
719        let mut client_certs = Vec::new();
720
721        // For each node in the map, either spawn a task to pull from that node,
722        // or setup the node as allowed to pull from us.
723        for (origin, node) in replication_node_map.iter() {
724            // Setup client certs
725            match node {
726                RepNodeConfig::MutualPull {
727                    partner_cert: consumer_cert,
728                    automatic_refresh: _,
729                }
730                | RepNodeConfig::AllowPull { consumer_cert } => {
731                    let Ok(consumer_cert_der) = consumer_cert.to_der().map(CertificateDer::from)
732                    else {
733                        warn!("WARNING: Unable to convert client cert to DER.");
734                        continue;
735                    };
736
737                    client_certs.push(consumer_cert_der)
738                }
739                RepNodeConfig::Pull {
740                    supplier_cert: _,
741                    automatic_refresh: _,
742                } => {}
743            };
744
745            match node {
746                RepNodeConfig::MutualPull {
747                    partner_cert: supplier_cert,
748                    automatic_refresh,
749                }
750                | RepNodeConfig::Pull {
751                    supplier_cert,
752                    automatic_refresh,
753                } => {
754                    let Ok(supplier_cert_der) = supplier_cert.to_der().map(CertificateDer::from)
755                    else {
756                        warn!("WARNING: Unable to convert client cert to DER.");
757                        continue;
758                    };
759
760                    let task_rx = task_tx.subscribe();
761
762                    let handle: JoinHandle<()> = tokio::spawn(repl_task(
763                        origin.clone(),
764                        server_key_der.clone_key(),
765                        server_cert_der.clone(),
766                        supplier_cert_der.clone(),
767                        consumer_conn_settings.clone(),
768                        task_rx,
769                        *automatic_refresh,
770                        idms.clone(),
771                    ));
772
773                    task_handles.push_back(handle);
774                    debug_assert_eq!(task_handles.len(), task_tx.receiver_count());
775                }
776                RepNodeConfig::AllowPull { consumer_cert: _ } => {}
777            };
778        }
779
780        // ⚠️  This section is critical to the security of replication
781        //    Since replication relies on mTLS we MUST ensure these options
782        //    are absolutely correct!
783        //
784        // Setup the TLS builder.
785
786        // ⚠️  CRITICAL - ensure that the cert store only has client certs from
787        // the repl map added.
788        let mut client_cert_roots = RootCertStore::empty();
789
790        for client_cert in client_certs.into_iter() {
791            if let Err(err) = client_cert_roots.add(client_cert) {
792                error!(?err, "CRITICAL, unable to add client certificate.");
793                sleep(retry_timeout).await;
794                continue;
795            }
796        }
797
798        let provider: Arc<_> = rustls::crypto::aws_lc_rs::default_provider().into();
799
800        let client_cert_verifier_result =
801            WebPkiClientVerifier::builder_with_provider(client_cert_roots.into(), provider.clone())
802                // We don't allow clients that lack a certificate to correct.
803                // allow_unauthenticated()
804                .build();
805
806        let client_cert_verifier = match client_cert_verifier_result {
807            Ok(ccv) => ccv,
808            Err(err) => {
809                error!(
810                    ?err,
811                    "CRITICAL, unable to configure client certificate verifier."
812                );
813                sleep(retry_timeout).await;
814                continue;
815            }
816        };
817
818        let tls_server_config = match ServerConfig::builder_with_provider(provider)
819            .with_safe_default_protocol_versions()
820            .and_then(|builder| {
821                builder
822                    .with_client_cert_verifier(client_cert_verifier)
823                    .with_single_cert(vec![server_cert_der], server_key_der)
824            }) {
825            Ok(tls_server_config) => tls_server_config,
826            Err(err) => {
827                error!(
828                    ?err,
829                    "CRITICAL, unable to create TLS Server Config. Will retry ..."
830                );
831                sleep(retry_timeout).await;
832                continue;
833            }
834        };
835
836        let tls_acceptor = TlsAcceptor::from(Arc::new(tls_server_config));
837
838        loop {
839            // This is great to diagnose when spans are entered or present and they capture
840            // things incorrectly.
841            // eprintln!("🔥 C ---> {:?}", tracing::Span::current());
842            let eventid = Uuid::new_v4();
843
844            tokio::select! {
845                Ok(action) = rx.recv() => {
846                    match action {
847                        CoreAction::Shutdown => break 'event,
848                    }
849                }
850                Some(ctrl_msg) = ctrl_rx.recv() => {
851                    match ctrl_msg {
852                        ReplCtrl::GetCertificate {
853                            respond
854                        } => {
855                            let _span = debug_span!("supplier_accept_loop", uuid = ?eventid).entered();
856                            if respond.send(server_cert.clone()).is_err() {
857                                warn!("Server certificate was requested, but requsetor disconnected");
858                            } else {
859                                trace!("Sent server certificate via control channel");
860                            }
861                        }
862                        ReplCtrl::RenewCertificate {
863                            respond
864                        } => {
865                            let span = debug_span!("supplier_accept_loop", uuid = ?eventid);
866                            async {
867                                debug!("renewing replication certificate ...");
868                                // Renew the cert.
869                                let res = {
870                                    let ct = duration_from_epoch_now();
871                                    idms.proxy_write(ct).await
872                                        .and_then(|mut idms_prox_write|
873                                    idms_prox_write
874                                        .qs_write
875                                        .supplier_renew_key_cert(&domain_name)
876                                        .and_then(|res| idms_prox_write.commit().map(|()| res))
877                                        )
878                                };
879
880                                let success = res.is_ok();
881
882                                if let Err(err) = res {
883                                    error!(?err, "failed to renew server certificate");
884                                }
885
886                                if respond.send(success).is_err() {
887                                    warn!("Server certificate renewal was requested, but requester disconnected!");
888                                } else {
889                                    trace!("Sent server certificate renewal status via control channel");
890                                }
891                            }
892                            .instrument(span)
893                            .await;
894
895                            // Start a reload.
896                            continue 'event;
897                        }
898                        ReplCtrl::RefreshConsumer {
899                            respond
900                        } => {
901                            // Indicate to consumer tasks that they should do a refresh.
902                            let (tx, rx) = mpsc::channel(1);
903
904                            let refresh_coord = Arc::new(
905                                Mutex::new(
906                                (
907                                    false, tx
908                                )
909                                )
910                            );
911
912                            if task_tx.send(ReplConsumerCtrl::Refresh(refresh_coord)).is_err() {
913                                error!("Unable to begin replication consumer refresh, tasks are unable to be notified.");
914                            }
915
916                            if respond.send(rx).is_err() {
917                                warn!("Replication consumer refresh was requested, but requester disconnected");
918                            } else {
919                                trace!("Sent refresh comms channel to requester");
920                            }
921                        }
922                    }
923                }
924                // Handle accepts.
925                // Handle *reloads*
926                /*
927                _ = reload.recv() => {
928                    info!("Initiating TLS reload");
929                    continue
930                }
931                */
932                accept_result = listener.accept() => {
933                    match accept_result {
934                        Ok((tcpstream, client_socket_addr)) => {
935                            let clone_idms = idms.clone();
936                            let clone_tls_acceptor = tls_acceptor.clone();
937                            // We don't care about the join handle here - once a client connects
938                            // it sticks to whatever ssl settings it had at launch.
939                            tokio::spawn(
940                                handle_repl_conn(max_frame_bytes, tcpstream, client_socket_addr, clone_tls_acceptor, clone_idms)
941                            );
942                        }
943                        Err(e) => {
944                            error!("replication acceptor error, continuing -> {:?}", e);
945                        }
946                    }
947                }
948            } // end select
949              // Continue to poll/loop
950        }
951    }
952    // Shutdown child tasks.
953    info!("Stopping {} Replication Tasks ...", task_handles.len());
954    debug_assert!(task_handles.len() >= task_tx.receiver_count());
955    let _ = task_tx.send(ReplConsumerCtrl::Stop);
956    for task_handle in task_handles.drain(..) {
957        // Let each task join.
958        let res: Result<(), _> = task_handle.await.map(|_| ());
959        if res.is_err() {
960            warn!("Failed to join replication task, continuing ...");
961        }
962    }
963
964    info!("Stopped {}", super::TaskName::Replication);
965}