kanidmd_core/repl/
mod.rs

1use self::codec::{ConsumerRequest, SupplierResponse};
2use crate::CoreAction;
3use config::{RepNodeConfig, ReplicationConfiguration};
4use futures_util::sink::SinkExt;
5use futures_util::stream::StreamExt;
6use kanidmd_lib::prelude::duration_from_epoch_now;
7use kanidmd_lib::prelude::IdmServer;
8use kanidmd_lib::repl::proto::ConsumerState;
9use kanidmd_lib::server::QueryServerTransaction;
10use openssl::x509::X509;
11use rustls::{
12    client::ClientConfig,
13    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
14    server::{ServerConfig, WebPkiClientVerifier},
15    RootCertStore,
16};
17use std::collections::VecDeque;
18use std::net::SocketAddr;
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::io::AsyncWriteExt;
22use tokio::sync::broadcast;
23use tokio::sync::mpsc;
24use tokio::sync::oneshot;
25use tokio::sync::{Mutex, MutexGuard};
26use tokio::time::{interval, sleep, timeout};
27use tokio::{
28    net::{TcpListener, TcpStream},
29    task::JoinHandle,
30};
31use tokio_rustls::{client::TlsStream, TlsAcceptor, TlsConnector};
32use tokio_util::codec::{Framed, FramedRead, FramedWrite};
33use tracing::{error, Instrument};
34use url::Url;
35use uuid::Uuid;
36
37mod codec;
38pub(crate) mod config;
39
/// Control requests sent to the replication acceptor over its `mpsc` control
/// channel (the sending half is returned by `create_repl_server`).
///
/// Every variant carries a `oneshot` sender on which the reply is delivered.
pub(crate) enum ReplCtrl {
    /// Ask for a certificate; the reply is an [`X509`].
    /// NOTE(review): presumably the current replication certificate — the
    /// handler is outside this excerpt, confirm there.
    GetCertificate {
        respond: oneshot::Sender<X509>,
    },
    /// Ask for the certificate to be renewed; the reply is a `bool`.
    /// NOTE(review): presumably `true` on success — confirm against the handler.
    RenewCertificate {
        respond: oneshot::Sender<bool>,
    },
    /// Ask for a consumer refresh; the reply hands back an `mpsc::Receiver<()>`.
    /// NOTE(review): presumably the receiver yields once the refresh has
    /// completed — confirm against the handler.
    RefreshConsumer {
        respond: oneshot::Sender<mpsc::Receiver<()>>,
    },
}
51
/// Messages broadcast to the per-supplier replication tasks.
#[derive(Debug, Clone)]
enum ReplConsumerCtrl {
    /// Tells a replication task to leave its loop and finish.
    Stop,
    /// Asks the tasks to perform a full refresh. The shared, mutex-protected
    /// tuple coordinates them: the `bool` is set once a refresh has been
    /// applied (later tasks see it and return early), and the sender is used
    /// to signal that completion to whoever started the refresh.
    Refresh(Arc<Mutex<(bool, mpsc::Sender<()>)>>),
}
57
58pub(crate) async fn create_repl_server(
59    idms: Arc<IdmServer>,
60    repl_config: &ReplicationConfiguration,
61    rx: broadcast::Receiver<CoreAction>,
62) -> Result<(tokio::task::JoinHandle<()>, mpsc::Sender<ReplCtrl>), ()> {
63    // We need to start the tcp listener. This will persist over ssl reloads!
64    let listener = TcpListener::bind(&repl_config.bindaddress)
65        .await
66        .map_err(|e| {
67            error!(
68                "Could not bind to replication address {} -> {:?}",
69                repl_config.bindaddress, e
70            );
71        })?;
72
73    // Create the control channel. Use a low msg count, there won't be that much going on.
74    let (ctrl_tx, ctrl_rx) = mpsc::channel(4);
75
76    // We need to start the tcp listener. This will persist over ssl reloads!
77    info!(
78        "Starting replication interface https://{} ...",
79        repl_config.bindaddress
80    );
81    let repl_handle: JoinHandle<()> = tokio::spawn(repl_acceptor(
82        listener,
83        idms,
84        repl_config.clone(),
85        rx,
86        ctrl_rx,
87    ));
88
89    info!("Created replication interface");
90    Ok((repl_handle, ctrl_tx))
91}
92
93#[instrument(level = "debug", skip_all)]
94/// This returns the remote address that worked, so you can try that first next time
95async fn repl_consumer_connect_supplier(
96    server_name: &ServerName<'static>,
97    sock_addrs: &[SocketAddr],
98    tls_connector: &TlsConnector,
99    consumer_conn_settings: &ConsumerConnSettings,
100) -> Option<(
101    SocketAddr,
102    Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
103)> {
104    // This is pretty gnarly, but we need to loop to try out each socket addr.
105    for sock_addr in sock_addrs {
106        debug!(
107            "Attempting to connect to {} replica via {}",
108            server_name.to_str(),
109            sock_addr
110        );
111
112        let tcpstream = match timeout(
113            consumer_conn_settings.replica_connect_timeout,
114            TcpStream::connect(sock_addr),
115        )
116        .await
117        {
118            Ok(Ok(tc)) => {
119                trace!("Connection established to peer on {:?}", sock_addr);
120                tc
121            }
122            Ok(Err(err)) => {
123                debug!(?err, "Failed to connect to {}", sock_addr);
124                continue;
125            }
126            Err(_) => {
127                debug!("Timeout connecting to {}", sock_addr);
128                continue;
129            }
130        };
131
132        let tlsstream = match timeout(
133            consumer_conn_settings.replica_connect_timeout,
134            tls_connector.connect(server_name.to_owned(), tcpstream),
135        )
136        .await
137        {
138            Ok(Ok(ta)) => ta,
139            Ok(Err(e)) => {
140                error!("Replication client TLS setup error, continuing -> {:?}", e);
141                continue;
142            }
143            Err(_) => {
144                debug!("Timeout establishing TLS to {}", sock_addr);
145                continue;
146            }
147        };
148
149        let supplier_conn = Framed::new(
150            tlsstream,
151            codec::ConsumerCodec::new(consumer_conn_settings.max_frame_bytes),
152        );
153        // "hey this one worked, try it first next time!"
154        return Some((sock_addr.to_owned(), supplier_conn));
155    }
156
157    error!(
158        "Unable to connect to supplier, tried to connect to {:?}",
159        sock_addrs
160    );
161    None
162}
163
164async fn repl_consumer_disconnect_supplier(
165    supplier_conn: Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
166    consumer_conn_settings: &ConsumerConnSettings,
167) {
168    let mut tls_stream = supplier_conn.into_inner();
169
170    match timeout(
171        consumer_conn_settings.replica_connect_timeout,
172        tls_stream.shutdown(),
173    )
174    .await
175    {
176        Ok(Ok(_)) => debug!("Connection closed successfully"),
177        Ok(Err(tls_err)) => warn!(?tls_err, "Unable to cleanly shutdown TLS client connection"),
178        Err(_) => error!("Timeout attempting to close connection"),
179    }
180}
181
182/// This returns the socket address that worked, so you can try that first next time
183#[instrument(
184    level="info",
185    skip(refresh_coord, tls_connector, idms, consumer_conn_settings),
186    fields(eventid = Uuid::new_v4().to_string(), server_name = %server_name.to_str())
187)]
188async fn repl_run_consumer_refresh(
189    refresh_coord: Arc<Mutex<(bool, mpsc::Sender<()>)>>,
190    server_name: &ServerName<'static>,
191    sock_addrs: &[SocketAddr],
192    tls_connector: &TlsConnector,
193    idms: &IdmServer,
194    consumer_conn_settings: &ConsumerConnSettings,
195) -> Result<Option<SocketAddr>, ()> {
196    // Take the refresh lock. Note that every replication consumer *should* end up here
197    // behind this lock, but only one can proceed. This is what we want!
198
199    let refresh_coord_guard = refresh_coord.lock().await;
200
201    // Simple case - task is already done.
202    if refresh_coord_guard.0 {
203        trace!("Refresh already completed by another task, return.");
204        return Ok(None);
205    }
206
207    // Okay, we need to proceed. Open the connection.
208    let (addr, mut supplier_conn) = repl_consumer_connect_supplier(
209        server_name,
210        sock_addrs,
211        tls_connector,
212        consumer_conn_settings,
213    )
214    .await
215    .ok_or(())?;
216
217    let result = repl_run_consumer_refresh_inner(
218        addr,
219        &mut supplier_conn,
220        refresh_coord_guard,
221        idms,
222        consumer_conn_settings,
223    )
224    .await;
225
226    // disconnect the connection if possible.
227    repl_consumer_disconnect_supplier(supplier_conn, consumer_conn_settings).await;
228
229    result
230}
231
232async fn repl_run_consumer_refresh_inner(
233    addr: SocketAddr,
234    supplier_conn: &mut Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
235    mut refresh_coord_guard: MutexGuard<'_, (bool, mpsc::Sender<()>)>,
236    idms: &IdmServer,
237    consumer_conn_settings: &ConsumerConnSettings,
238) -> Result<Option<SocketAddr>, ()> {
239    // If we fail at any point, just RETURN because this leaves the next task to attempt, or
240    // the channel drops and that tells the caller this failed.
241
242    match timeout(
243        consumer_conn_settings.replica_connect_timeout,
244        supplier_conn.send(ConsumerRequest::Refresh),
245    )
246    .await
247    {
248        Ok(Ok(())) => {}
249        Ok(Err(err)) => {
250            error!(?err, "consumer encode error, unable to continue.");
251            return Err(());
252        }
253        Err(_) => {
254            error!("consumer request timeout error, unable to continue.");
255            return Err(());
256        }
257    };
258
259    let refresh = match timeout(
260        consumer_conn_settings.replica_connect_timeout,
261        supplier_conn.next(),
262    )
263    .await
264    {
265        Ok(Some(Ok(SupplierResponse::Refresh(changes)))) => {
266            // Success - return to bypass the error message.
267            changes
268        }
269        Ok(Some(Ok(SupplierResponse::Pong))) | Ok(Some(Ok(SupplierResponse::Incremental(_)))) => {
270            error!("Supplier Response contains invalid State");
271            return Err(());
272        }
273        Ok(Some(Err(codec_err))) => {
274            error!(?codec_err, "Consumer decode error, unable to continue.");
275            return Err(());
276        }
277        Ok(None) => {
278            error!("Connection closed");
279            return Err(());
280        }
281        Err(_) => {
282            error!("consumer response timeout error, unable to continue.");
283            return Err(());
284        }
285    };
286
287    // Now apply the refresh if possible
288    {
289        // Scope the transaction.
290        let ct = duration_from_epoch_now();
291        idms.proxy_write(ct)
292            .await
293            .and_then(|mut write_txn| {
294                write_txn
295                    .qs_write
296                    .consumer_apply_refresh(refresh)
297                    .and_then(|cs| write_txn.commit().map(|()| cs))
298            })
299            .map_err(|err| error!(?err, "Consumer was not able to apply refresh."))?;
300    }
301
302    // Now mark the refresh as complete AND indicate it to the channel.
303    refresh_coord_guard.0 = true;
304    if refresh_coord_guard.1.send(()).await.is_err() {
305        warn!("Unable to signal to caller that refresh has completed.");
306    }
307
308    // Here the coord guard will drop and every other task proceeds.
309
310    info!("Replication refresh was successful.");
311    Ok(Some(addr))
312}
313
314#[instrument(
315    level="info",
316    skip(tls_connector, idms, consumer_conn_settings, server_name),
317    fields(eventid = Uuid::new_v4().to_string(), server_name = %server_name.to_str())
318)]
319async fn repl_run_consumer(
320    server_name: &ServerName<'static>,
321    sock_addrs: &[SocketAddr],
322    tls_connector: &TlsConnector,
323    automatic_refresh: bool,
324    idms: &IdmServer,
325    consumer_conn_settings: &ConsumerConnSettings,
326    task_tx: &mut broadcast::Sender<ReplConsumerCtrl>,
327) -> Option<SocketAddr> {
328    let (socket_addr, mut supplier_conn) = repl_consumer_connect_supplier(
329        server_name,
330        sock_addrs,
331        tls_connector,
332        consumer_conn_settings,
333    )
334    .await?;
335
336    let result = repl_run_consumer_inner(
337        socket_addr,
338        &mut supplier_conn,
339        idms,
340        automatic_refresh,
341        consumer_conn_settings,
342        task_tx,
343    )
344    .await;
345
346    repl_consumer_disconnect_supplier(supplier_conn, consumer_conn_settings).await;
347
348    result
349}
350
351async fn repl_run_consumer_inner(
352    socket_addr: SocketAddr,
353    supplier_conn: &mut Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
354    idms: &IdmServer,
355    automatic_refresh: bool,
356    consumer_conn_settings: &ConsumerConnSettings,
357    task_tx: &mut broadcast::Sender<ReplConsumerCtrl>,
358) -> Option<SocketAddr> {
359    // Perform incremental.
360    let consumer_ruv_range = {
361        let consumer_state = idms
362            .proxy_read()
363            .await
364            .and_then(|mut read_txn| read_txn.qs_read.consumer_get_state());
365        match consumer_state {
366            Ok(ruv_range) => ruv_range,
367            Err(err) => {
368                error!(
369                    ?err,
370                    "consumer ruv range could not be accessed, unable to continue."
371                );
372                return None;
373            }
374        }
375    };
376
377    match timeout(
378        consumer_conn_settings.replica_connect_timeout,
379        supplier_conn.send(ConsumerRequest::Incremental(consumer_ruv_range)),
380    )
381    .await
382    {
383        Ok(Ok(())) => {}
384        Ok(Err(err)) => {
385            error!(?err, "consumer encode error, unable to continue.");
386            return None;
387        }
388        Err(_) => {
389            error!("consumer request timeout, unable to continue");
390            return None;
391        }
392    };
393
394    let changes = match timeout(
395        consumer_conn_settings.replica_connect_timeout,
396        supplier_conn.next(),
397    )
398    .await
399    {
400        Ok(Some(Ok(SupplierResponse::Incremental(changes)))) => changes,
401        Ok(Some(Ok(SupplierResponse::Pong))) | Ok(Some(Ok(SupplierResponse::Refresh(_)))) => {
402            error!("Supplier Response contains invalid state");
403            return None;
404        }
405        Ok(Some(Err(err))) => {
406            error!(?err, "Consumer decode error, unable to continue.");
407            return None;
408        }
409        Ok(None) => {
410            error!("Consumer connection closed, unable to continue.");
411            return None;
412        }
413
414        Err(_) => {
415            error!("consumer response timeout, unable to continue");
416            return None;
417        }
418    };
419
420    // Now apply the changes if possible
421    let consumer_state = {
422        let ct = duration_from_epoch_now();
423        match idms.proxy_write(ct).await.and_then(|mut write_txn| {
424            write_txn
425                .qs_write
426                .consumer_apply_changes(changes)
427                .and_then(|cs| write_txn.commit().map(|()| cs))
428        }) {
429            Ok(state) => state,
430            Err(err) => {
431                error!(?err, "Consumer was not able to apply changes.");
432                return None;
433            }
434        }
435    };
436
437    match consumer_state {
438        ConsumerState::Ok => {
439            info!("Incremental Replication Success");
440            // return to bypass the failure message.
441            Some(socket_addr)
442        }
443        ConsumerState::RefreshRequired => {
444            if automatic_refresh {
445                warn!("Consumer is out of date and must be refreshed. This will happen *now*.");
446
447                let (tx, mut rx) = mpsc::channel(1);
448
449                let refresh_coord = Arc::new(Mutex::new((false, tx)));
450
451                if task_tx
452                    .send(ReplConsumerCtrl::Refresh(refresh_coord))
453                    .is_err()
454                {
455                    error!("Unable to begin replication consumer refresh, tasks are unable to be notified.");
456                } else {
457                    tokio::spawn(async move {
458                        match rx.recv().await {
459                            Some(()) => {
460                                info!("replication consumer refresh complete!")
461                            }
462                            None => {
463                                warn!("refresh task was lost with no response!")
464                            }
465                        }
466                    });
467                }
468
469                None
470            } else {
471                error!("Consumer is out of date and must be refreshed. You must manually resolve this situation.");
472                None
473            }
474        }
475    }
476}
477
/// Tunables shared by the consumer-side connection helpers and tasks.
#[derive(Debug, Clone)]
struct ConsumerConnSettings {
    /// Passed to `codec::ConsumerCodec::new` when a supplier connection is
    /// framed; presumably the maximum accepted frame size in bytes.
    max_frame_bytes: usize,
    /// Period of the interval timer that drives each replication task.
    task_poll_interval: Duration,
    /// Timeout applied to each individual step of a consumer connection:
    /// TCP connect, TLS handshake, request send, response wait and shutdown.
    replica_connect_timeout: Duration,
}
484
485#[allow(clippy::too_many_arguments)]
486async fn repl_task(
487    origin: Url,
488
489    client_key: PrivateKeyDer<'static>,
490    client_cert: CertificateDer<'static>,
491    supplier_cert: CertificateDer<'static>,
492
493    consumer_conn_settings: ConsumerConnSettings,
494    mut task_rx: broadcast::Receiver<ReplConsumerCtrl>,
495    mut task_tx: broadcast::Sender<ReplConsumerCtrl>,
496    automatic_refresh: bool,
497    idms: Arc<IdmServer>,
498) {
499    if origin.scheme() != "repl" {
500        error!("Replica origin is not repl:// - refusing to proceed.");
501        return;
502    }
503
504    let domain = match origin.domain() {
505        Some(d) => d,
506        None => {
507            error!("Replica origin does not have a valid domain name, unable to proceed. Perhaps you tried to use an ip address?");
508            return;
509        }
510    };
511
512    let Ok(server_name) = ServerName::try_from(domain.to_owned()) else {
513        error!("Replica origin does not have a valid domain name, unable to proceed.");
514        return;
515    };
516
517    // Add the supplier cert.
518    // ⚠️  note that here we need to build a new cert store. This is because
519    // we want to pin a single certificate!
520    let mut root_cert_store = RootCertStore::empty();
521    if let Err(err) = root_cert_store.add(supplier_cert) {
522        error!(?err, "Replica supplier cert invalid.");
523        return;
524    };
525
526    let provider = rustls::crypto::aws_lc_rs::default_provider().into();
527
528    let tls_client_config = match ClientConfig::builder_with_provider(provider)
529        .with_safe_default_protocol_versions()
530        .and_then(|builder| {
531            builder
532                .with_root_certificates(root_cert_store)
533                .with_client_auth_cert(vec![client_cert], client_key)
534        }) {
535        Ok(ccb) => ccb,
536        Err(err) => {
537            error!(?err, "Unable to build TLS client configuration");
538            return;
539        }
540    };
541
542    let tls_connector = TlsConnector::from(Arc::new(tls_client_config));
543
544    let mut repl_interval = interval(consumer_conn_settings.task_poll_interval);
545
546    info!("Replica task for {} has started.", origin);
547
548    // we keep track of the "last known good" socketaddr so we can try that first next time.
549    let mut last_working_address: Option<SocketAddr> = None;
550
551    // Okay, all the parameters are set up. Now we replicate on our interval.
552    loop {
553        // we resolve the DNS entry to the ip:port each time we attempt a connection to avoid stale
554        // DNS issues, ref #3188. If we are unable to resolve the address, we backoff and try again
555        // as in something like docker the address may change frequently.
556        //
557        // Note, if DNS isn't available, we can proceed with the last used working address too. This
558        // prevents DNS (or lack thereof) from causing a replication outage.
559        let mut sorted_socket_addrs = vec![];
560
561        // If the target address worked last time, then let's use it this time!
562        if let Some(addr) = last_working_address {
563            debug!(?last_working_address);
564            sorted_socket_addrs.push(addr);
565        };
566
567        // Default to port 443 if not set in the origin
568        match origin.socket_addrs(|| Some(443)) {
569            Ok(mut socket_addrs) => {
570                // Make every address unique.
571                socket_addrs.sort_unstable();
572                socket_addrs.dedup();
573
574                // The only possible conflict is with the last working address,
575                // so lets just check that.
576                socket_addrs.into_iter().for_each(|addr| {
577                    if Some(&addr) != last_working_address.as_ref() {
578                        // Not already present, append
579                        sorted_socket_addrs.push(addr);
580                    }
581                });
582            }
583            Err(err) => {
584                if let Some(addr) = last_working_address {
585                    warn!(
586                        ?err,
587                        "Unable to resolve '{origin}' to ip:port, using last known working address '{addr}'"
588                    );
589                } else {
590                    warn!(?err, "Unable to resolve '{origin}' to ip:port.");
591                }
592            }
593        };
594
595        if sorted_socket_addrs.is_empty() {
596            warn!(
597                "No replication addresses available, delaying replication operation for '{origin}'"
598            );
599            repl_interval.tick().await;
600            continue;
601        }
602
603        tokio::select! {
604            Ok(task) = task_rx.recv() => {
605                match task {
606                    ReplConsumerCtrl::Stop => break,
607                    ReplConsumerCtrl::Refresh ( refresh_coord ) => {
608                        last_working_address = (repl_run_consumer_refresh(
609                            refresh_coord,
610                            &server_name,
611                            &sorted_socket_addrs,
612                            &tls_connector,
613                            &idms,
614                            &consumer_conn_settings
615                        )
616                        .await).unwrap_or_default();
617                    }
618                }
619            }
620            _ = repl_interval.tick() => {
621                // Interval passed, attempt a replication run.
622                repl_run_consumer(
623                    &server_name,
624                    &sorted_socket_addrs,
625                    &tls_connector,
626                    automatic_refresh,
627                    &idms,
628                    &consumer_conn_settings,
629                    &mut task_tx
630                )
631                .await;
632            }
633        }
634    }
635
636    info!("Replica task for {} has stopped.", origin);
637}
638
639#[instrument(level = "debug", skip_all)]
640async fn handle_repl_conn(
641    max_frame_bytes: usize,
642    tcpstream: TcpStream,
643    client_address: SocketAddr,
644    tls_acceptor: TlsAcceptor,
645    idms: Arc<IdmServer>,
646) {
647    debug!(?client_address, "replication client connected 🛫");
648
649    let tlsstream = match tls_acceptor.accept(tcpstream).await {
650        Ok(ta) => ta,
651        Err(err) => {
652            error!(?err, "Replication TLS setup error, disconnecting client");
653            return;
654        }
655    };
656
657    let (r, w) = tokio::io::split(tlsstream);
658    let mut r = FramedRead::new(r, codec::SupplierCodec::new(max_frame_bytes));
659    let mut w = FramedWrite::new(w, codec::SupplierCodec::new(max_frame_bytes));
660
661    while let Some(codec_msg) = r.next().await {
662        match codec_msg {
663            Ok(ConsumerRequest::Ping) => {
664                debug!("consumer requested ping");
665                if let Err(err) = w.send(SupplierResponse::Pong).await {
666                    error!(?err, "supplier encode error, unable to continue.");
667                    break;
668                }
669            }
670            Ok(ConsumerRequest::Incremental(consumer_ruv_range)) => {
671                let changes = match idms.proxy_read().await.and_then(|mut read_txn| {
672                    read_txn
673                        .qs_read
674                        .supplier_provide_changes(consumer_ruv_range)
675                }) {
676                    Ok(changes) => changes,
677                    Err(err) => {
678                        error!(?err, "supplier provide changes failed.");
679                        break;
680                    }
681                };
682
683                if let Err(err) = w.send(SupplierResponse::Incremental(changes)).await {
684                    error!(?err, "supplier encode error, unable to continue.");
685                    break;
686                }
687            }
688            Ok(ConsumerRequest::Refresh) => {
689                let changes = match idms
690                    .proxy_read()
691                    .await
692                    .and_then(|mut read_txn| read_txn.qs_read.supplier_provide_refresh())
693                {
694                    Ok(changes) => changes,
695                    Err(err) => {
696                        error!(?err, "supplier provide refresh failed.");
697                        break;
698                    }
699                };
700
701                if let Err(err) = w.send(SupplierResponse::Refresh(changes)).await {
702                    error!(?err, "supplier encode error, unable to continue.");
703                    break;
704                }
705            }
706            Err(err) => {
707                error!(?err, "supplier decode error, unable to continue.");
708                break;
709            }
710        }
711    }
712
713    debug!(?client_address, "replication client disconnected 🛬");
714}
715
716/// This is the main acceptor for the replication server.
717async fn repl_acceptor(
718    listener: TcpListener,
719    idms: Arc<IdmServer>,
720    repl_config: ReplicationConfiguration,
721    mut rx: broadcast::Receiver<CoreAction>,
722    mut ctrl_rx: mpsc::Receiver<ReplCtrl>,
723) {
724    info!("Starting Replication Acceptor ...");
725    // Persistent parts
726    // These all probably need changes later ...
727    let replica_connect_timeout = Duration::from_secs(5);
728    let mut retry_timeout = Duration::from_secs(1);
729    let max_frame_bytes = 268435456;
730
731    let consumer_conn_settings = ConsumerConnSettings {
732        max_frame_bytes,
733        task_poll_interval: repl_config.get_task_poll_interval(),
734        replica_connect_timeout,
735    };
736
737    // Setup a broadcast to control our tasks.
738    let (task_tx, task_rx1) = broadcast::channel(1);
739    // Note, we drop this task here since each task will re-subscribe. That way the
740    // broadcast doesn't jam up because we aren't draining this task.
741    drop(task_rx1);
742    let mut task_handles = VecDeque::new();
743
744    // Create another broadcast to control the replication tasks and their need to reload.
745
746    // Spawn a KRC communication task?
747
748    // In future we need to update this from the KRC if configured, and we default this
749    // to "empty". But if this map exists in the config, we have to always use that.
750    let replication_node_map = repl_config.manual.clone();
751    let domain_name = match repl_config.origin.domain() {
752        Some(n) => n.to_string(),
753        None => {
754            error!("Unable to start replication, replication origin does not contain a valid domain name.");
755            return;
756        }
757    };
758
759    // This needs to have an event loop that can respond to changes.
760    // For now we just design it to reload ssl if the map changes internally.
761    'event: loop {
762        // Don't block shutdowns while we are waiting here.
763        tokio::select! {
764            Ok(action) = rx.recv() => {
765                match action {
766                    CoreAction::Shutdown => break 'event,
767                }
768            }
769            _ = sleep(retry_timeout) => {}
770        }
771
772        // The timeout is initially small, we increase it here to prevent spinning too much.
773        retry_timeout = Duration::from_secs(60);
774
775        info!("Starting replication reload ...");
776        // Tell existing tasks to shutdown.
777        // Note: We ignore the result here since an err can occur *if* there are
778        // no tasks currently listening on the channel.
779        info!("Stopping {} Replication Tasks ...", task_handles.len());
780        debug_assert!(task_handles.len() >= task_tx.receiver_count());
781        let _ = task_tx.send(ReplConsumerCtrl::Stop);
782        for task_handle in task_handles.drain(..) {
783            // Let each task join.
784            let res: Result<(), _> = task_handle.await;
785            if res.is_err() {
786                warn!("Failed to join replication task, continuing ...");
787            }
788        }
789
790        // Now we can start to re-load configurations and setup our client tasks
791        // as well.
792
793        // Get our private key / cert.
794        let res = {
795            let ct = duration_from_epoch_now();
796            idms.proxy_write(ct).await.and_then(|mut idms_prox_write| {
797                idms_prox_write
798                    .qs_write
799                    .supplier_get_key_cert(&domain_name)
800                    .and_then(|res| idms_prox_write.commit().map(|()| res))
801            })
802        };
803
804        let (server_key, server_cert) = match res {
805            Ok(r) => r,
806            Err(err) => {
807                error!(?err, "CRITICAL: Unable to access supplier certificate/key.");
808                continue 'event;
809            }
810        };
811
812        info!(
813            replication_cert_not_before = ?server_cert.not_before(),
814            replication_cert_not_after = ?server_cert.not_after(),
815        );
816
817        // rustls expects these to be der
818        let Ok(server_key_der) = server_key.private_key_to_der() else {
819            error!("CRITICAL: Unable to convert server key to DER.");
820            continue 'event;
821        };
822
823        let Ok(server_key_der) = PrivateKeyDer::try_from(server_key_der) else {
824            error!("CRITICAL: Unable to convert server key from DER.");
825            continue 'event;
826        };
827
828        let Ok(server_cert_der) = server_cert.to_der().map(CertificateDer::from) else {
829            error!("CRITICAL: Unable to convert server cert to DER.");
830            continue 'event;
831        };
832
833        let mut client_certs = Vec::new();
834
835        // For each node in the map, either spawn a task to pull from that node,
836        // or setup the node as allowed to pull from us.
837        for (origin, node) in replication_node_map.iter() {
838            // Setup client certs
839            match node {
840                RepNodeConfig::MutualPull {
841                    partner_cert: consumer_cert,
842                    automatic_refresh: _,
843                }
844                | RepNodeConfig::AllowPull { consumer_cert } => {
845                    let Ok(consumer_cert_der) = consumer_cert.to_der().map(CertificateDer::from)
846                    else {
847                        warn!("WARNING: Unable to convert client cert to DER.");
848                        continue 'event;
849                    };
850
851                    client_certs.push(consumer_cert_der)
852                }
853                RepNodeConfig::Pull {
854                    supplier_cert: _,
855                    automatic_refresh: _,
856                } => {}
857            };
858
859            match node {
860                RepNodeConfig::MutualPull {
861                    partner_cert: supplier_cert,
862                    automatic_refresh,
863                }
864                | RepNodeConfig::Pull {
865                    supplier_cert,
866                    automatic_refresh,
867                } => {
868                    let Ok(supplier_cert_der) = supplier_cert.to_der().map(CertificateDer::from)
869                    else {
870                        warn!("WARNING: Unable to convert client cert to DER.");
871                        continue 'event;
872                    };
873
874                    let task_rx = task_tx.subscribe();
875                    let task_tx_c = task_tx.clone();
876
877                    let handle: JoinHandle<()> = tokio::spawn(repl_task(
878                        origin.clone(),
879                        server_key_der.clone_key(),
880                        server_cert_der.clone(),
881                        supplier_cert_der.clone(),
882                        consumer_conn_settings.clone(),
883                        task_rx,
884                        task_tx_c,
885                        *automatic_refresh,
886                        idms.clone(),
887                    ));
888
889                    task_handles.push_back(handle);
890                    debug_assert_eq!(task_handles.len(), task_tx.receiver_count());
891                }
892                RepNodeConfig::AllowPull { consumer_cert: _ } => {}
893            };
894        }
895
896        // ⚠️  This section is critical to the security of replication
897        //    Since replication relies on mTLS we MUST ensure these options
898        //    are absolutely correct!
899        //
900        // Setup the TLS builder.
901
902        // ⚠️  CRITICAL - ensure that the cert store only has client certs from
903        // the repl map added.
904
905        let tls_acceptor = if client_certs.is_empty() {
906            warn!("No replication client certs are available, replication connections will be ignored.");
907            None
908        } else {
909            let mut client_cert_roots = RootCertStore::empty();
910
911            for client_cert in client_certs.into_iter() {
912                if let Err(err) = client_cert_roots.add(client_cert) {
913                    error!(?err, "CRITICAL, unable to add client certificate.");
914                    continue 'event;
915                }
916            }
917
918            let provider: Arc<_> = rustls::crypto::aws_lc_rs::default_provider().into();
919
920            let client_cert_verifier_result = WebPkiClientVerifier::builder_with_provider(
921                client_cert_roots.into(),
922                provider.clone(),
923            )
924            // We don't allow clients that lack a certificate to connect.
925            // allow_unauthenticated()
926            .build();
927
928            let client_cert_verifier = match client_cert_verifier_result {
929                Ok(ccv) => ccv,
930                Err(err) => {
931                    error!(
932                        ?err,
933                        "CRITICAL, unable to configure client certificate verifier."
934                    );
935                    continue 'event;
936                }
937            };
938
939            let tls_server_config = match ServerConfig::builder_with_provider(provider)
940                .with_safe_default_protocol_versions()
941                .and_then(|builder| {
942                    builder
943                        .with_client_cert_verifier(client_cert_verifier)
944                        .with_single_cert(vec![server_cert_der], server_key_der)
945                }) {
946                Ok(tls_server_config) => tls_server_config,
947                Err(err) => {
948                    error!(
949                        ?err,
950                        "CRITICAL, unable to create TLS Server Config. Will retry ..."
951                    );
952                    continue 'event;
953                }
954            };
955
956            Some(TlsAcceptor::from(Arc::new(tls_server_config)))
957        };
958
959        loop {
960            // This is great to diagnose when spans are entered or present and they capture
961            // things incorrectly.
962            // eprintln!("🔥 C ---> {:?}", tracing::Span::current());
963            let eventid = Uuid::new_v4();
964
965            tokio::select! {
966                Ok(action) = rx.recv() => {
967                    match action {
968                        CoreAction::Shutdown => break 'event,
969                    }
970                }
971                Some(ctrl_msg) = ctrl_rx.recv() => {
972                    match ctrl_msg {
973                        ReplCtrl::GetCertificate {
974                            respond
975                        } => {
976                            let _span = debug_span!("supplier_accept_loop", uuid = ?eventid).entered();
977                            if respond.send(server_cert.clone()).is_err() {
978                                warn!("Server certificate was requested, but requsetor disconnected");
979                            } else {
980                                trace!("Sent server certificate via control channel");
981                            }
982                        }
983                        ReplCtrl::RenewCertificate {
984                            respond
985                        } => {
986                            let span = debug_span!("supplier_accept_loop", uuid = ?eventid);
987                            async {
988                                debug!("renewing replication certificate ...");
989                                // Renew the cert.
990                                let res = {
991                                    let ct = duration_from_epoch_now();
992                                    idms.proxy_write(ct).await
993                                        .and_then(|mut idms_prox_write|
994                                    idms_prox_write
995                                        .qs_write
996                                        .supplier_renew_key_cert(&domain_name)
997                                        .and_then(|res| idms_prox_write.commit().map(|()| res))
998                                        )
999                                };
1000
1001                                let success = res.is_ok();
1002
1003                                if let Err(err) = res {
1004                                    error!(?err, "failed to renew server certificate");
1005                                }
1006
1007                                if respond.send(success).is_err() {
1008                                    warn!("Server certificate renewal was requested, but requester disconnected!");
1009                                } else {
1010                                    trace!("Sent server certificate renewal status via control channel");
1011                                }
1012                            }
1013                            .instrument(span)
1014                            .await;
1015
1016                            // Start a reload.
1017                            continue 'event;
1018                        }
1019                        ReplCtrl::RefreshConsumer {
1020                            respond
1021                        } => {
1022                            // Indicate to consumer tasks that they should do a refresh.
1023                            let (tx, rx) = mpsc::channel(1);
1024
1025                            let refresh_coord = Arc::new(
1026                                Mutex::new(
1027                                (
1028                                    false, tx
1029                                )
1030                                )
1031                            );
1032
1033                            if task_tx.send(ReplConsumerCtrl::Refresh(refresh_coord)).is_err() {
1034                                error!("Unable to begin replication consumer refresh, tasks are unable to be notified.");
1035                            }
1036
1037                            if respond.send(rx).is_err() {
1038                                warn!("Replication consumer refresh was requested, but requester disconnected");
1039                            } else {
1040                                trace!("Sent refresh comms channel to requester");
1041                            }
1042                        }
1043                    }
1044                }
1045                // Handle accepts.
1046                // Handle *reloads*
1047                /*
1048                _ = reload.recv() => {
1049                    info!("Initiating TLS reload");
1050                    continue
1051                }
1052                */
1053                accept_result = listener.accept() => {
1054                    match accept_result {
1055                        Ok((tcpstream, client_socket_addr)) => {
1056                            if let Some(clone_tls_acceptor) = tls_acceptor.clone() {
1057                                let clone_idms = idms.clone();
1058                                // We don't care about the join handle here - once a client connects
1059                                // it sticks to whatever ssl settings it had at launch.
1060                                tokio::spawn(
1061                                    handle_repl_conn(max_frame_bytes, tcpstream, client_socket_addr, clone_tls_acceptor, clone_idms)
1062                                );
1063                            } else {
1064                                // TLS is not set up, generally because no accepted/trusted client
1065                                // certs are present. Drop the connection.
1066                                warn!("Ignoring connection from {client_socket_addr} as replication is not configured correctly.");
1067                                warn!("This is because you have not configured this server with trusted partner certificates.");
1068                            }
1069                        }
1070                        Err(e) => {
1071                            error!("replication acceptor error, continuing -> {:?}", e);
1072                        }
1073                    }
1074                }
1075            } // end select
1076              // Continue to poll/loop
1077        }
1078    }
1079    // Shutdown child tasks.
1080    info!("Stopping {} Replication Tasks ...", task_handles.len());
1081    debug_assert!(task_handles.len() >= task_tx.receiver_count());
1082    let _ = task_tx.send(ReplConsumerCtrl::Stop);
1083    for task_handle in task_handles.drain(..) {
1084        // Let each task join.
1085        let res: Result<(), _> = task_handle.await.map(|_| ());
1086        if res.is_err() {
1087            warn!("Failed to join replication task, continuing ...");
1088        }
1089    }
1090
1091    info!("Stopped {}", super::TaskName::Replication);
1092}