kanidmd_core/repl/mod.rs

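//! Network layer for kanidmd replication: per-node consumer tasks that pull
//! changes from remote suppliers over mutually authenticated TLS (mTLS), and
//! the supplier-side acceptor that serves those pulls.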
use self::codec::{ConsumerRequest, SupplierResponse};
use crate::CoreAction;
use config::{RepNodeConfig, ReplicationConfiguration};
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use kanidmd_lib::prelude::duration_from_epoch_now;
use kanidmd_lib::prelude::IdmServer;
use kanidmd_lib::repl::proto::ConsumerState;
use kanidmd_lib::server::QueryServerTransaction;
use openssl::x509::X509;
use rustls::{
    client::ClientConfig,
    pki_types::{CertificateDer, PrivateKeyDer, ServerName},
    server::{ServerConfig, WebPkiClientVerifier},
    RootCertStore,
};
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::{Mutex, MutexGuard};
use tokio::time::{interval, sleep, timeout};
use tokio::{
    net::{TcpListener, TcpStream},
    task::JoinHandle,
};
use tokio_rustls::{client::TlsStream, TlsAcceptor, TlsConnector};
use tokio_util::codec::{Framed, FramedRead, FramedWrite};
use tracing::{error, Instrument};
use url::Url;
use uuid::Uuid;

mod codec;
pub(crate) mod config;

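/// Control messages accepted by the replication acceptor. Each variant carries
/// a channel on which the result is sent back to the caller.
///
/// A minimal usage sketch (hypothetical caller; `ctrl_tx` is the sender
/// returned by [`create_repl_server`]):
///
/// ```ignore
/// let (tx, rx) = oneshot::channel();
/// ctrl_tx.send(ReplCtrl::GetCertificate { respond: tx }).await?;
/// let cert: X509 = rx.await?;
/// ```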
pub(crate) enum ReplCtrl {
    GetCertificate {
        respond: oneshot::Sender<X509>,
    },
    RenewCertificate {
        respond: oneshot::Sender<bool>,
    },
    RefreshConsumer {
        respond: oneshot::Sender<mpsc::Receiver<()>>,
    },
}

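/// Broadcast messages fanned out to every consumer task. `Refresh` carries a
/// shared coordination lock: the boolean records whether the refresh has
/// already completed, and the channel notifies the original requester.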
#[derive(Debug, Clone)]
enum ReplConsumerCtrl {
    Stop,
    Refresh(Arc<Mutex<(bool, mpsc::Sender<()>)>>),
}

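/// Bind the replication TCP listener and spawn the acceptor task. Returns the
/// task handle and the control channel used for certificate and refresh
/// operations.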
pub(crate) async fn create_repl_server(
    idms: Arc<IdmServer>,
    repl_config: &ReplicationConfiguration,
    rx: broadcast::Receiver<CoreAction>,
) -> Result<(tokio::task::JoinHandle<()>, mpsc::Sender<ReplCtrl>), ()> {
    // We need to start the tcp listener. This will persist over ssl reloads!
    let listener = TcpListener::bind(&repl_config.bindaddress)
        .await
        .map_err(|e| {
            error!(
                "Could not bind to replication address {} -> {:?}",
                repl_config.bindaddress, e
            );
        })?;

    // Create the control channel. Use a low msg count, there won't be that much going on.
    let (ctrl_tx, ctrl_rx) = mpsc::channel(4);

    info!(
        "Starting replication interface https://{} ...",
        repl_config.bindaddress
    );
    let repl_handle: JoinHandle<()> = tokio::spawn(repl_acceptor(
        listener,
        idms,
        repl_config.clone(),
        rx,
        ctrl_rx,
    ));

    info!("Created replication interface");
    Ok((repl_handle, ctrl_tx))
}

#[instrument(level = "debug", skip_all)]
/// This returns the remote address that worked, so you can try that first next time
async fn repl_consumer_connect_supplier(
    server_name: &ServerName<'static>,
    sock_addrs: &[SocketAddr],
    tls_connector: &TlsConnector,
    consumer_conn_settings: &ConsumerConnSettings,
) -> Option<(
    SocketAddr,
    Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
)> {
    // This is pretty gnarly, but we need to loop to try out each socket addr.
    for sock_addr in sock_addrs {
        debug!(
            "Attempting to connect to {} replica via {}",
            server_name.to_str(),
            sock_addr
        );

        let tcpstream = match timeout(
            consumer_conn_settings.replica_connect_timeout,
            TcpStream::connect(sock_addr),
        )
        .await
        {
            Ok(Ok(tc)) => {
                trace!("Connection established to peer on {:?}", sock_addr);
                tc
            }
            Ok(Err(err)) => {
                debug!(?err, "Failed to connect to {}", sock_addr);
                continue;
            }
            Err(_) => {
                debug!("Timeout connecting to {}", sock_addr);
                continue;
            }
        };

        let tlsstream = match tls_connector
            .connect(server_name.to_owned(), tcpstream)
            .await
        {
            Ok(ta) => ta,
            Err(e) => {
                error!("Replication client TLS setup error, continuing -> {:?}", e);
                continue;
            }
        };

        let supplier_conn = Framed::new(
            tlsstream,
            codec::ConsumerCodec::new(consumer_conn_settings.max_frame_bytes),
        );
        // "hey this one worked, try it first next time!"
        return Some((sock_addr.to_owned(), supplier_conn));
    }

    error!(
        "Unable to connect to supplier, tried to connect to {:?}",
        sock_addrs
    );
    None
}

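/// Cleanly shut down the TLS session with the supplier before dropping the
/// connection.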
async fn repl_consumer_disconnect_supplier(
    supplier_conn: Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
) {
    let mut tls_stream = supplier_conn.into_inner();
    if let Err(tls_err) = tls_stream.shutdown().await {
        warn!(?tls_err, "Unable to cleanly shutdown TLS client connection");
    }
}

/// This returns the socket address that worked, so you can try that first next time
#[instrument(
    level="info",
    skip(refresh_coord, tls_connector, idms, consumer_conn_settings),
    fields(eventid = Uuid::new_v4().to_string(), server_name = %server_name.to_str())
)]
async fn repl_run_consumer_refresh(
    refresh_coord: Arc<Mutex<(bool, mpsc::Sender<()>)>>,
    server_name: &ServerName<'static>,
    sock_addrs: &[SocketAddr],
    tls_connector: &TlsConnector,
    idms: &IdmServer,
    consumer_conn_settings: &ConsumerConnSettings,
) -> Result<Option<SocketAddr>, ()> {
    // Take the refresh lock. Note that every replication consumer *should* end up here
    // behind this lock, but only one can proceed. This is what we want!

    let refresh_coord_guard = refresh_coord.lock().await;

    // Simple case - task is already done.
    if refresh_coord_guard.0 {
        trace!("Refresh already completed by another task, return.");
        return Ok(None);
    }

    // Okay, we need to proceed. Open the connection.
    let (addr, mut supplier_conn) = repl_consumer_connect_supplier(
        server_name,
        sock_addrs,
        tls_connector,
        consumer_conn_settings,
    )
    .await
    .ok_or(())?;

    let result =
        repl_run_consumer_refresh_inner(addr, &mut supplier_conn, refresh_coord_guard, idms).await;

    // Disconnect from the supplier if possible.
    repl_consumer_disconnect_supplier(supplier_conn).await;

    result
}

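/// Inner refresh body, split out so the caller can always disconnect the
/// supplier connection afterwards. The `MutexGuard` is held across the whole
/// refresh, keeping the other consumer tasks queued until this attempt
/// resolves.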
async fn repl_run_consumer_refresh_inner(
    addr: SocketAddr,
    supplier_conn: &mut Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
    mut refresh_coord_guard: MutexGuard<'_, (bool, mpsc::Sender<()>)>,
    idms: &IdmServer,
) -> Result<Option<SocketAddr>, ()> {
    // If we fail at any point, just RETURN. This leaves the next task to attempt the
    // refresh, or the channel drops, which tells the caller this failed.
    supplier_conn
        .send(ConsumerRequest::Refresh)
        .await
        .map_err(|err| error!(?err, "consumer encode error, unable to continue."))?;

    let refresh = if let Some(codec_msg) = supplier_conn.next().await {
        match codec_msg.map_err(|err| error!(?err, "Consumer decode error, unable to continue."))? {
            SupplierResponse::Refresh(changes) => {
                // Success - bind the changes and continue.
                changes
            }
            SupplierResponse::Pong | SupplierResponse::Incremental(_) => {
                error!("Supplier response contains invalid state");
                return Err(());
            }
        }
    } else {
        error!("Connection closed");
        return Err(());
    };

    // Now apply the refresh if possible
    {
        // Scope the transaction.
        let ct = duration_from_epoch_now();
        idms.proxy_write(ct)
            .await
            .and_then(|mut write_txn| {
                write_txn
                    .qs_write
                    .consumer_apply_refresh(refresh)
                    .and_then(|cs| write_txn.commit().map(|()| cs))
            })
            .map_err(|err| error!(?err, "Consumer was not able to apply refresh."))?;
    }

    // Now mark the refresh as complete AND indicate it to the channel.
    refresh_coord_guard.0 = true;
    if refresh_coord_guard.1.send(()).await.is_err() {
        warn!("Unable to signal to caller that refresh has completed.");
    }

    // Here the coord guard will drop and every other task proceeds.

    info!("Replication refresh was successful.");
    Ok(Some(addr))
}

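/// Run one incremental replication pass against a supplier, falling back to a
/// full refresh when applying the changes reports `RefreshRequired` and
/// `automatic_refresh` permits it.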
#[instrument(
    level="info",
    skip(tls_connector, idms, consumer_conn_settings, server_name),
    fields(eventid = Uuid::new_v4().to_string(), server_name = %server_name.to_str())
)]
async fn repl_run_consumer(
    server_name: &ServerName<'static>,
    sock_addrs: &[SocketAddr],
    tls_connector: &TlsConnector,
    automatic_refresh: bool,
    idms: &IdmServer,
    consumer_conn_settings: &ConsumerConnSettings,
) -> Option<SocketAddr> {
    let (socket_addr, mut supplier_conn) = repl_consumer_connect_supplier(
        server_name,
        sock_addrs,
        tls_connector,
        consumer_conn_settings,
    )
    .await?;

    let result =
        repl_run_consumer_inner(socket_addr, &mut supplier_conn, idms, automatic_refresh).await;

    repl_consumer_disconnect_supplier(supplier_conn).await;

    result
}

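/// Inner incremental body, split out so the caller always shuts the connection
/// down. Returns the working socket address on success so it can be preferred
/// on the next attempt.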
async fn repl_run_consumer_inner(
    socket_addr: SocketAddr,
    supplier_conn: &mut Framed<TlsStream<TcpStream>, codec::ConsumerCodec>,
    idms: &IdmServer,
    automatic_refresh: bool,
) -> Option<SocketAddr> {
    // Perform incremental.
    let consumer_ruv_range = {
        let consumer_state = idms
            .proxy_read()
            .await
            .and_then(|mut read_txn| read_txn.qs_read.consumer_get_state());
        match consumer_state {
            Ok(ruv_range) => ruv_range,
            Err(err) => {
                error!(
                    ?err,
                    "consumer ruv range could not be accessed, unable to continue."
                );
                return None;
            }
        }
    };

    if let Err(err) = supplier_conn
        .send(ConsumerRequest::Incremental(consumer_ruv_range))
        .await
    {
        error!(?err, "consumer encode error, unable to continue.");
        return None;
    }

    let changes = if let Some(codec_msg) = supplier_conn.next().await {
        match codec_msg {
            Ok(SupplierResponse::Incremental(changes)) => {
                // Success - bind the changes and continue.
                changes
            }
            Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Refresh(_)) => {
                error!("Supplier response contains invalid state");
                return None;
            }
            Err(err) => {
                error!(?err, "Consumer decode error, unable to continue.");
                return None;
            }
        }
    } else {
        error!("Connection closed");
        return None;
    };

    // Now apply the changes if possible
    let consumer_state = {
        let ct = duration_from_epoch_now();
        match idms.proxy_write(ct).await.and_then(|mut write_txn| {
            write_txn
                .qs_write
                .consumer_apply_changes(changes)
                .and_then(|cs| write_txn.commit().map(|()| cs))
        }) {
            Ok(state) => state,
            Err(err) => {
                error!(?err, "Consumer was not able to apply changes.");
                return None;
            }
        }
    };

    match consumer_state {
        ConsumerState::Ok => {
            info!("Incremental Replication Success");
            // Return early to bypass the refresh path below.
            return Some(socket_addr);
        }
        ConsumerState::RefreshRequired => {
            if automatic_refresh {
                warn!("Consumer is out of date and must be refreshed. This will happen *now*.");
            } else {
                error!("Consumer is out of date and must be refreshed. You must manually resolve this situation.");
                return None;
            };
        }
    }

    if let Err(err) = supplier_conn.send(ConsumerRequest::Refresh).await {
        error!(?err, "consumer encode error, unable to continue.");
        return None;
    }

    let refresh = if let Some(codec_msg) = supplier_conn.next().await {
        match codec_msg {
            Ok(SupplierResponse::Refresh(changes)) => {
                // Success - bind the changes and continue.
                changes
            }
            Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Incremental(_)) => {
                error!("Supplier response contains invalid state");
                return None;
            }
            Err(err) => {
                error!(?err, "consumer decode error, unable to continue.");
                return None;
            }
        }
    } else {
        error!("Connection closed");
        return None;
    };

    // Now apply the refresh if possible
    let ct = duration_from_epoch_now();
    if let Err(err) = idms.proxy_write(ct).await.and_then(|mut write_txn| {
        write_txn
            .qs_write
            .consumer_apply_refresh(refresh)
            .and_then(|cs| write_txn.commit().map(|()| cs))
    }) {
        error!(?err, "consumer was not able to apply refresh.");
        return None;
    }

    info!("Replication refresh was successful.");
    Some(socket_addr)
}

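/// Connection tuning shared by every consumer task: the maximum codec frame
/// size, how often to poll the supplier, and how long to wait for a TCP
/// connection to establish.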
#[derive(Debug, Clone)]
struct ConsumerConnSettings {
    max_frame_bytes: usize,
    task_poll_interval: Duration,
    replica_connect_timeout: Duration,
}

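/// Long-running task that replicates from a single supplier identified by
/// `origin`. Builds an mTLS client config that pins the supplier certificate,
/// then loops: re-resolving the origin, running incremental pulls on an
/// interval, and responding to Stop/Refresh broadcasts.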
#[allow(clippy::too_many_arguments)]
async fn repl_task(
    origin: Url,

    client_key: PrivateKeyDer<'static>,
    client_cert: CertificateDer<'static>,
    supplier_cert: CertificateDer<'static>,

    consumer_conn_settings: ConsumerConnSettings,
    mut task_rx: broadcast::Receiver<ReplConsumerCtrl>,
    automatic_refresh: bool,
    idms: Arc<IdmServer>,
) {
    if origin.scheme() != "repl" {
        error!("Replica origin is not repl:// - refusing to proceed.");
        return;
    }

    let domain = match origin.domain() {
        Some(d) => d,
        None => {
            error!("Replica origin does not have a valid domain name, unable to proceed. Perhaps you tried to use an IP address?");
            return;
        }
    };

    let Ok(server_name) = ServerName::try_from(domain.to_owned()) else {
        error!("Replica origin does not have a valid domain name, unable to proceed.");
        return;
    };

    // Add the supplier cert.
    // ⚠️  Note that here we need to build a new cert store. This is because
    // we want to pin a single certificate!
    let mut root_cert_store = RootCertStore::empty();
    if let Err(err) = root_cert_store.add(supplier_cert) {
        error!(?err, "Replica supplier cert invalid.");
        return;
    };

    let provider = rustls::crypto::aws_lc_rs::default_provider().into();

    let tls_client_config = match ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .and_then(|builder| {
            builder
                .with_root_certificates(root_cert_store)
                .with_client_auth_cert(vec![client_cert], client_key)
        }) {
        Ok(ccb) => ccb,
        Err(err) => {
            error!(?err, "Unable to build TLS client configuration");
            return;
        }
    };

    let tls_connector = TlsConnector::from(Arc::new(tls_client_config));

    let mut repl_interval = interval(consumer_conn_settings.task_poll_interval);

    info!("Replica task for {} has started.", origin);

    // We keep track of the "last known good" socket address so we can try it first next time.
    let mut last_working_address: Option<SocketAddr> = None;

    // Okay, all the parameters are set up. Now we replicate on our interval.
    loop {
        // We resolve the DNS entry to the ip:port each time we attempt a connection to avoid
        // stale DNS issues, ref #3188. If we are unable to resolve the address, we back off
        // and try again, since in environments like docker the address may change frequently.
        //
        // Note, if DNS isn't available, we can proceed with the last used working address too.
        // This prevents DNS (or lack thereof) from causing a replication outage.
        let mut sorted_socket_addrs = vec![];

        // If the target address worked last time, then let's use it this time!
        if let Some(addr) = last_working_address {
            debug!(?last_working_address);
            sorted_socket_addrs.push(addr);
        };

        // Default to port 443 if not set in the origin
        match origin.socket_addrs(|| Some(443)) {
            Ok(mut socket_addrs) => {
                // Make every address unique.
                socket_addrs.sort_unstable();
                socket_addrs.dedup();

                // The only possible conflict is with the last working address,
                // so let's just check that.
                socket_addrs.into_iter().for_each(|addr| {
                    if Some(&addr) != last_working_address.as_ref() {
                        // Not already present, append
                        sorted_socket_addrs.push(addr);
                    }
                });
            }
            Err(err) => {
                if let Some(addr) = last_working_address {
                    warn!(
                        ?err,
                        "Unable to resolve '{origin}' to ip:port, using last known working address '{addr}'"
                    );
                } else {
                    warn!(?err, "Unable to resolve '{origin}' to ip:port.");
                }
            }
        };

        if sorted_socket_addrs.is_empty() {
            warn!(
                "No replication addresses available, delaying replication operation for '{origin}'"
            );
            repl_interval.tick().await;
            continue;
        }

        tokio::select! {
            Ok(task) = task_rx.recv() => {
                match task {
                    ReplConsumerCtrl::Stop => break,
                    ReplConsumerCtrl::Refresh(refresh_coord) => {
                        last_working_address = (repl_run_consumer_refresh(
                            refresh_coord,
                            &server_name,
                            &sorted_socket_addrs,
                            &tls_connector,
                            &idms,
                            &consumer_conn_settings
                        )
                        .await).unwrap_or_default();
                    }
                }
            }
            _ = repl_interval.tick() => {
                // Interval passed, attempt a replication run.
                repl_run_consumer(
                    &server_name,
                    &sorted_socket_addrs,
                    &tls_connector,
                    automatic_refresh,
                    &idms,
                    &consumer_conn_settings
                )
                .await;
            }
        }
    }

    info!("Replica task for {} has stopped.", origin);
}

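/// Serve one inbound consumer connection: complete the mTLS handshake, then
/// answer Ping/Incremental/Refresh requests until the peer disconnects or a
/// codec error occurs.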
#[instrument(level = "debug", skip_all)]
async fn handle_repl_conn(
    max_frame_bytes: usize,
    tcpstream: TcpStream,
    client_address: SocketAddr,
    tls_acceptor: TlsAcceptor,
    idms: Arc<IdmServer>,
) {
    debug!(?client_address, "replication client connected 🛫");

    let tlsstream = match tls_acceptor.accept(tcpstream).await {
        Ok(ta) => ta,
        Err(err) => {
            error!(?err, "Replication TLS setup error, disconnecting client");
            return;
        }
    };

    let (r, w) = tokio::io::split(tlsstream);
    let mut r = FramedRead::new(r, codec::SupplierCodec::new(max_frame_bytes));
    let mut w = FramedWrite::new(w, codec::SupplierCodec::new(max_frame_bytes));

    while let Some(codec_msg) = r.next().await {
        match codec_msg {
            Ok(ConsumerRequest::Ping) => {
                debug!("consumer requested ping");
                if let Err(err) = w.send(SupplierResponse::Pong).await {
                    error!(?err, "supplier encode error, unable to continue.");
                    break;
                }
            }
            Ok(ConsumerRequest::Incremental(consumer_ruv_range)) => {
                let changes = match idms.proxy_read().await.and_then(|mut read_txn| {
                    read_txn
                        .qs_read
                        .supplier_provide_changes(consumer_ruv_range)
                }) {
                    Ok(changes) => changes,
                    Err(err) => {
                        error!(?err, "supplier provide changes failed.");
                        break;
                    }
                };

                if let Err(err) = w.send(SupplierResponse::Incremental(changes)).await {
                    error!(?err, "supplier encode error, unable to continue.");
                    break;
                }
            }
            Ok(ConsumerRequest::Refresh) => {
                let changes = match idms
                    .proxy_read()
                    .await
                    .and_then(|mut read_txn| read_txn.qs_read.supplier_provide_refresh())
                {
                    Ok(changes) => changes,
                    Err(err) => {
                        error!(?err, "supplier provide refresh failed.");
                        break;
                    }
                };

                if let Err(err) = w.send(SupplierResponse::Refresh(changes)).await {
                    error!(?err, "supplier encode error, unable to continue.");
                    break;
                }
            }
            Err(err) => {
                error!(?err, "supplier decode error, unable to continue.");
                break;
            }
        }
    }

    debug!(?client_address, "replication client disconnected 🛬");
}

/// This is the main acceptor for the replication server.
async fn repl_acceptor(
    listener: TcpListener,
    idms: Arc<IdmServer>,
    repl_config: ReplicationConfiguration,
    mut rx: broadcast::Receiver<CoreAction>,
    mut ctrl_rx: mpsc::Receiver<ReplCtrl>,
) {
    info!("Starting Replication Acceptor ...");
    // Persistent parts
    // These all probably need changes later ...
    let replica_connect_timeout = Duration::from_secs(2);
    let mut retry_timeout = Duration::from_secs(1);
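    // 256 MiB maximum frame size for the replication codec.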
    let max_frame_bytes = 268435456;

    let consumer_conn_settings = ConsumerConnSettings {
        max_frame_bytes,
        task_poll_interval: repl_config.get_task_poll_interval(),
        replica_connect_timeout,
    };

    // Setup a broadcast to control our tasks.
    let (task_tx, task_rx1) = broadcast::channel(1);
    // Note, we drop this receiver here since each task will re-subscribe. That way the
    // broadcast doesn't jam up because we aren't draining this receiver.
    drop(task_rx1);
    let mut task_handles = VecDeque::new();

    // Create another broadcast to control the replication tasks and their need to reload.

    // Spawn a KRC communication task?

    // In future we need to update this from the KRC if configured, and we default this
    // to "empty". But if this map exists in the config, we have to always use that.
    let replication_node_map = repl_config.manual.clone();
    let domain_name = match repl_config.origin.domain() {
        Some(n) => n.to_string(),
        None => {
            error!("Unable to start replication, replication origin does not contain a valid domain name.");
            return;
        }
    };

    // This needs to have an event loop that can respond to changes.
    // For now we just design it to reload ssl if the map changes internally.
    'event: loop {
        // Don't block shutdowns while we are waiting here.
        tokio::select! {
            Ok(action) = rx.recv() => {
                match action {
                    CoreAction::Shutdown => break 'event,
                }
            }
            _ = sleep(retry_timeout) => {}
        }

        // The retry timeout is initially small; we increase it here to prevent the loop
        // from spinning too much.
        retry_timeout = Duration::from_secs(60);

        info!("Starting replication reload ...");
        // Tell existing tasks to shut down.
        // Note: We ignore the result here since an err can occur *if* there are
        // no tasks currently listening on the channel.
        info!("Stopping {} Replication Tasks ...", task_handles.len());
        debug_assert!(task_handles.len() >= task_tx.receiver_count());
        let _ = task_tx.send(ReplConsumerCtrl::Stop);
        for task_handle in task_handles.drain(..) {
            // Let each task join.
            let res: Result<(), _> = task_handle.await;
            if res.is_err() {
                warn!("Failed to join replication task, continuing ...");
            }
        }

        // Now we can start to reload configurations and set up our client tasks
        // as well.

        // Get our private key / cert.
        let res = {
            let ct = duration_from_epoch_now();
            idms.proxy_write(ct).await.and_then(|mut idms_prox_write| {
                idms_prox_write
                    .qs_write
                    .supplier_get_key_cert(&domain_name)
                    .and_then(|res| idms_prox_write.commit().map(|()| res))
            })
        };

        let (server_key, server_cert) = match res {
            Ok(r) => r,
            Err(err) => {
                error!(?err, "CRITICAL: Unable to access supplier certificate/key.");
                continue 'event;
            }
        };

        info!(
            replication_cert_not_before = ?server_cert.not_before(),
            replication_cert_not_after = ?server_cert.not_after(),
        );
        // rustls expects these to be DER
        let Ok(server_key_der) = server_key.private_key_to_der() else {
            error!("CRITICAL: Unable to convert server key to DER.");
            continue 'event;
        };

        let Ok(server_key_der) = PrivateKeyDer::try_from(server_key_der) else {
            error!("CRITICAL: Unable to convert server key from DER.");
            continue 'event;
        };

        let Ok(server_cert_der) = server_cert.to_der().map(CertificateDer::from) else {
            error!("CRITICAL: Unable to convert server cert to DER.");
            continue 'event;
        };

        let mut client_certs = Vec::new();

        // For each node in the map, either spawn a task to pull from that node,
        // or set up the node as allowed to pull from us.
        for (origin, node) in replication_node_map.iter() {
            // Setup client certs
            match node {
                RepNodeConfig::MutualPull {
                    partner_cert: consumer_cert,
                    automatic_refresh: _,
                }
                | RepNodeConfig::AllowPull { consumer_cert } => {
                    let Ok(consumer_cert_der) = consumer_cert.to_der().map(CertificateDer::from)
                    else {
                        warn!("Unable to convert consumer cert to DER.");
                        continue 'event;
                    };

                    client_certs.push(consumer_cert_der)
                }
                RepNodeConfig::Pull {
                    supplier_cert: _,
                    automatic_refresh: _,
                } => {}
            };

            match node {
                RepNodeConfig::MutualPull {
                    partner_cert: supplier_cert,
                    automatic_refresh,
                }
                | RepNodeConfig::Pull {
                    supplier_cert,
                    automatic_refresh,
                } => {
                    let Ok(supplier_cert_der) = supplier_cert.to_der().map(CertificateDer::from)
                    else {
                        warn!("Unable to convert supplier cert to DER.");
                        continue 'event;
                    };

                    let task_rx = task_tx.subscribe();

                    let handle: JoinHandle<()> = tokio::spawn(repl_task(
                        origin.clone(),
                        server_key_der.clone_key(),
                        server_cert_der.clone(),
                        supplier_cert_der.clone(),
                        consumer_conn_settings.clone(),
                        task_rx,
                        *automatic_refresh,
                        idms.clone(),
                    ));

                    task_handles.push_back(handle);
                    debug_assert_eq!(task_handles.len(), task_tx.receiver_count());
                }
                RepNodeConfig::AllowPull { consumer_cert: _ } => {}
            };
        }

        // ⚠️  This section is critical to the security of replication.
        //    Since replication relies on mTLS we MUST ensure these options
        //    are absolutely correct!
        //
        // Setup the TLS builder.

        // ⚠️  CRITICAL - ensure that the cert store only has client certs from
        // the repl map added.

        let tls_acceptor = if client_certs.is_empty() {
            warn!("No replication client certs are available, replication connections will be ignored.");
            None
        } else {
            let mut client_cert_roots = RootCertStore::empty();

            for client_cert in client_certs.into_iter() {
                if let Err(err) = client_cert_roots.add(client_cert) {
                    error!(?err, "CRITICAL, unable to add client certificate.");
                    continue 'event;
                }
            }

            let provider: Arc<_> = rustls::crypto::aws_lc_rs::default_provider().into();

            let client_cert_verifier_result = WebPkiClientVerifier::builder_with_provider(
                client_cert_roots.into(),
                provider.clone(),
            )
            // We don't allow clients that lack a certificate to connect.
            // allow_unauthenticated()
            .build();

            let client_cert_verifier = match client_cert_verifier_result {
                Ok(ccv) => ccv,
                Err(err) => {
                    error!(
                        ?err,
                        "CRITICAL, unable to configure client certificate verifier."
                    );
                    continue 'event;
                }
            };

            let tls_server_config = match ServerConfig::builder_with_provider(provider)
                .with_safe_default_protocol_versions()
                .and_then(|builder| {
                    builder
                        .with_client_cert_verifier(client_cert_verifier)
                        .with_single_cert(vec![server_cert_der], server_key_der)
                }) {
                Ok(tls_server_config) => tls_server_config,
                Err(err) => {
                    error!(
                        ?err,
                        "CRITICAL, unable to create TLS Server Config. Will retry ..."
                    );
                    continue 'event;
                }
            };

            Some(TlsAcceptor::from(Arc::new(tls_server_config)))
        };

        loop {
            // This is great to diagnose when spans are entered or present and they capture
            // things incorrectly.
            // eprintln!("🔥 C ---> {:?}", tracing::Span::current());
            let eventid = Uuid::new_v4();

            tokio::select! {
                Ok(action) = rx.recv() => {
                    match action {
                        CoreAction::Shutdown => break 'event,
                    }
                }
                Some(ctrl_msg) = ctrl_rx.recv() => {
                    match ctrl_msg {
                        ReplCtrl::GetCertificate {
                            respond
                        } => {
                            let _span = debug_span!("supplier_accept_loop", uuid = ?eventid).entered();
                            if respond.send(server_cert.clone()).is_err() {
                                warn!("Server certificate was requested, but requester disconnected");
                            } else {
                                trace!("Sent server certificate via control channel");
                            }
                        }
                        ReplCtrl::RenewCertificate {
                            respond
                        } => {
                            let span = debug_span!("supplier_accept_loop", uuid = ?eventid);
                            async {
                                debug!("renewing replication certificate ...");
                                // Renew the cert.
                                let res = {
                                    let ct = duration_from_epoch_now();
                                    idms.proxy_write(ct).await
                                        .and_then(|mut idms_prox_write|
                                    idms_prox_write
                                        .qs_write
                                        .supplier_renew_key_cert(&domain_name)
                                        .and_then(|res| idms_prox_write.commit().map(|()| res))
                                        )
                                };

                                let success = res.is_ok();

                                if let Err(err) = res {
                                    error!(?err, "failed to renew server certificate");
                                }

                                if respond.send(success).is_err() {
                                    warn!("Server certificate renewal was requested, but requester disconnected!");
                                } else {
                                    trace!("Sent server certificate renewal status via control channel");
                                }
                            }
                            .instrument(span)
                            .await;

                            // Start a reload.
                            continue 'event;
                        }
                        ReplCtrl::RefreshConsumer {
                            respond
                        } => {
                            // Indicate to consumer tasks that they should do a refresh.
                            let (tx, rx) = mpsc::channel(1);

                            let refresh_coord = Arc::new(Mutex::new((false, tx)));

                            if task_tx.send(ReplConsumerCtrl::Refresh(refresh_coord)).is_err() {
                                error!("Unable to begin replication consumer refresh, tasks are unable to be notified.");
                            }

                            if respond.send(rx).is_err() {
                                warn!("Replication consumer refresh was requested, but requester disconnected");
                            } else {
                                trace!("Sent refresh comms channel to requester");
                            }
                        }
                    }
                }
                // Handle accepts.
                // Handle *reloads*
                /*
                _ = reload.recv() => {
                    info!("Initiating TLS reload");
                    continue
                }
                */
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((tcpstream, client_socket_addr)) => {
                            if let Some(clone_tls_acceptor) = tls_acceptor.clone() {
                                let clone_idms = idms.clone();
                                // We don't care about the join handle here - once a client connects
                                // it sticks to whatever ssl settings it had at launch.
                                tokio::spawn(
                                    handle_repl_conn(max_frame_bytes, tcpstream, client_socket_addr, clone_tls_acceptor, clone_idms)
                                );
                            } else {
                                // TLS is not set up, generally due to no accepted/trusted client
                                // certs being present. Drop the connection.
                                warn!("Ignoring connection from {client_socket_addr} as replication is not configured correctly.");
                                warn!("This is because you have not configured this server with trusted partner certificates.");
                            }
                        }
                        Err(e) => {
                            error!("replication acceptor error, continuing -> {:?}", e);
                        }
                    }
                }
            } // end select
              // Continue to poll/loop
        }
    }
    // Shutdown child tasks.
    info!("Stopping {} Replication Tasks ...", task_handles.len());
    debug_assert!(task_handles.len() >= task_tx.receiver_count());
    let _ = task_tx.send(ReplConsumerCtrl::Stop);
    for task_handle in task_handles.drain(..) {
        // Let each task join.
        let res: Result<(), _> = task_handle.await.map(|_| ());
        if res.is_err() {
            warn!("Failed to join replication task, continuing ...");
        }
    }

    info!("Stopped {}", super::TaskName::Replication);
}