Implement reconnect logic for client (untested)

pull/146/head
Chip Senkbeil 2 years ago
parent 24518f9882
commit e4ebd87718
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,15 +1,14 @@
use crate::common::{Connection, Interest, Reconnectable, Request, Transport, UntypedResponse};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{
io,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use tokio::{
io,
sync::{mpsc, oneshot},
sync::mpsc,
task::{JoinError, JoinHandle},
};
@ -19,6 +18,9 @@ pub use builder::*;
mod channel;
pub use channel::*;
mod reconnect;
pub use reconnect::*;
/// Time to wait in between connection read/write when nothing was read or written on last pass
const SLEEP_DURATION: Duration = Duration::from_millis(50);
@ -27,9 +29,6 @@ pub struct Client<T, U> {
/// Used to send requests to a server
channel: Channel<T, U>,
/// Used to send reconnect request to inner transport
reconnect_tx: mpsc::Sender<oneshot::Sender<io::Result<()>>>,
/// Used to send shutdown request to inner transport
shutdown_tx: mpsc::Sender<()>,
@ -43,14 +42,13 @@ where
U: Send + Sync + DeserializeOwned + 'static,
{
/// Spawns a client using the provided [`Connection`].
fn spawn<V>(mut connection: Connection<V>) -> Self
fn spawn<V>(mut connection: Connection<V>, mut strategy: ReconnectStrategy) -> Self
where
V: Transport + Send + Sync + 'static,
{
let post_office = Arc::new(PostOffice::default());
let weak_post_office = Arc::downgrade(&post_office);
let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
let (reconnect_tx, mut reconnect_rx) = mpsc::channel::<oneshot::Sender<io::Result<()>>>(1);
let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
// Ensure that our transport starts off clean (nothing in buffers or backup)
@ -59,34 +57,35 @@ where
// Start a task that continually checks for responses and delivers them using the
// post office
let task = tokio::spawn(async move {
let mut needs_retry = false;
loop {
if needs_retry {
info!("Client encountered issue, attempting to reconnect");
if log::log_enabled!(log::Level::Debug) {
debug!("Using strategy {strategy:?}");
}
match strategy.reconnect(&mut connection).await {
Ok(x) => x,
Err(x) => {
error!("Unable to re-establish connection: {x}");
break;
}
}
}
let ready = tokio::select! {
_ = shutdown_rx.recv() => {
debug!("Client got shutdown signal, so exiting event loop");
break;
}
cb = reconnect_rx.recv() => {
debug!("Client got reconnect signal, so attempting to reconnect");
if let Some(cb) = cb {
let _ = match Reconnectable::reconnect(&mut connection).await {
Ok(()) => cb.send(Ok(())),
Err(x) => {
error!("Client reconnect failed: {x}");
cb.send(Err(x))
}
};
continue;
} else {
error!("Client callback for reconnect missing! Corrupt state!");
break;
}
}
result = connection.ready(Interest::READABLE | Interest::WRITABLE) => {
match result {
Ok(result) => result,
Err(x) => {
error!("Failed to examine ready state: {x}");
break;
needs_retry = true;
continue;
}
}
}
@ -126,7 +125,8 @@ where
}
Ok(None) => {
debug!("Connection closed");
break;
needs_retry = true;
continue;
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
@ -184,7 +184,6 @@ where
Self {
channel,
reconnect_tx,
shutdown_tx,
task,
}
@ -248,26 +247,6 @@ impl<T, U> Client<T, U> {
}
}
#[async_trait]
impl<T, U> Reconnectable for Client<T, U>
where
T: Send,
U: Send + Sync,
{
async fn reconnect(&mut self) -> io::Result<()> {
let (tx, rx) = oneshot::channel();
if self.reconnect_tx.send(tx).await.is_ok() {
rx.await
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Callback lost"))?
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Client internal task dead",
))
}
}
}
impl<T, U> Deref for Client<T, U> {
type Target = Channel<T, U>;

@ -13,7 +13,7 @@ mod windows;
#[cfg(windows)]
pub use windows::*;
use crate::client::Client;
use crate::client::{Client, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, Connection, Transport};
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, future::Future, io, time::Duration};
@ -21,6 +21,7 @@ use std::{convert, future::Future, io, time::Duration};
/// Builder for a [`Client`]
///
/// `H` is the type of the authentication handler and `T` the type of the transport; both start
/// out as `()` and are replaced through the `auth_handler` and `transport` setters.
pub struct ClientBuilder<H, T> {
/// Authentication handler handed to `Connection::client` when the connection is established
auth_handler: H,
/// Strategy handed to `Client::spawn`, used when the client needs to reconnect
reconnect_strategy: ReconnectStrategy,
/// Transport handed to `Connection::client` when the connection is established
transport: T,
/// Optional time limit for connecting; `None` presumably means no limit — confirm against the
/// `match timeout` in the connect routine
timeout: Option<Duration>,
}
@ -29,6 +30,16 @@ impl<H, T> ClientBuilder<H, T> {
pub fn auth_handler<U>(self, auth_handler: U) -> ClientBuilder<U, T> {
ClientBuilder {
auth_handler,
reconnect_strategy: self.reconnect_strategy,
transport: self.transport,
timeout: self.timeout,
}
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> ClientBuilder<H, T> {
ClientBuilder {
auth_handler: self.auth_handler,
reconnect_strategy,
transport: self.transport,
timeout: self.timeout,
}
@ -51,6 +62,7 @@ impl<H, T> ClientBuilder<H, T> {
pub fn transport<U>(self, transport: U) -> ClientBuilder<H, U> {
ClientBuilder {
auth_handler: self.auth_handler,
reconnect_strategy: self.reconnect_strategy,
transport,
timeout: self.timeout,
}
@ -59,6 +71,7 @@ impl<H, T> ClientBuilder<H, T> {
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
reconnect_strategy: self.reconnect_strategy,
transport: self.transport,
timeout: timeout.into(),
}
@ -69,6 +82,7 @@ impl ClientBuilder<(), ()> {
pub fn new() -> Self {
Self {
auth_handler: (),
reconnect_strategy: ReconnectStrategy::default(),
transport: (),
timeout: None,
}
@ -92,12 +106,13 @@ where
V: Send + Sync + DeserializeOwned + 'static,
{
let auth_handler = self.auth_handler;
let retry_strategy = self.reconnect_strategy;
let timeout = self.timeout;
let transport = self.transport;
let f = async move {
let connection = Connection::client(transport, auth_handler).await?;
Ok(Client::spawn(connection))
Ok(Client::spawn(connection, retry_strategy))
};
match timeout {

@ -1,4 +1,4 @@
use crate::client::{Client, ClientBuilder};
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, TcpTransport};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{io, net::ToSocketAddrs, time::Duration};
@ -11,6 +11,10 @@ impl<T> TcpClientBuilder<T> {
TcpClientBuilder(self.0.auth_handler(auth_handler))
}
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> TcpClientBuilder<T> {
TcpClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self(self.0.timeout(timeout))
}

@ -1,4 +1,4 @@
use crate::client::{Client, ClientBuilder};
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, UnixSocketTransport};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path;
@ -12,6 +12,13 @@ impl<T> UnixSocketClientBuilder<T> {
UnixSocketClientBuilder(self.0.auth_handler(auth_handler))
}
pub fn reconnect_strategy(
self,
reconnect_strategy: ReconnectStrategy,
) -> UnixSocketClientBuilder<T> {
UnixSocketClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self(self.0.timeout(timeout))
}

@ -1,4 +1,4 @@
use crate::client::{Client, ClientBuilder};
use crate::client::{Client, ClientBuilder, ReconnectStrategy};
use crate::common::{authentication::AuthHandler, WindowsPipeTransport};
use serde::{de::DeserializeOwned, Serialize};
use std::ffi::{OsStr, OsString};
@ -18,6 +18,10 @@ impl<T> WindowsPipeClientBuilder<T> {
}
}
/// Sets the strategy used when reconnecting by forwarding it to the wrapped builder.
// NOTE(review): this constructs and reads `WindowsPipeClientBuilder` as a tuple struct (`self.0`).
// The `local` setter that follows suggests the struct may carry named fields instead; since this
// code is only built on Windows, confirm it compiles there and that no setting is dropped.
pub fn reconnect_strategy(self, reconnect_strategy: ReconnectStrategy) -> WindowsPipeClientBuilder<T> {
WindowsPipeClientBuilder(self.0.reconnect_strategy(reconnect_strategy))
}
/// If true, will connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}`; otherwise, will connect using the address verbatim.
pub fn local(self, local: bool) -> Self {

@ -0,0 +1,183 @@
use super::Reconnectable;
use std::io;
use std::time::Duration;
/// Represents the strategy to apply when attempting to reconnect the client to the server.
///
/// Every variant describes how long to wait between reconnect attempts, how many attempts to
/// make, and how long a single attempt may take.
///
/// Strategies compare by value. Only [`PartialEq`] is derived (not `Eq`) because the exponential
/// variant carries a floating-point factor.
#[derive(Clone, Debug, PartialEq)]
pub enum ReconnectStrategy {
    /// A retry strategy driven by exponential back-off: the wait between attempts is multiplied
    /// by `factor` after every failed attempt.
    ExponentialBackoff {
        /// Represents the initial time to wait between reconnect attempts.
        base: Duration,

        /// Factor to use when modifying the retry time, used as a multiplier.
        factor: f64,

        /// Represents the maximum duration to wait between attempts. None indicates no limit.
        max_duration: Option<Duration>,

        /// Represents the maximum attempts to make before failing. None indicates no limit.
        max_retries: Option<usize>,

        /// Represents the maximum time to wait for a single reconnect attempt. None indicates
        /// no limit.
        timeout: Option<Duration>,
    },

    /// A retry strategy driven by the fibonacci series: the next wait is the sum of the two
    /// previous waits.
    FibonacciBackoff {
        /// Represents the initial time to wait between reconnect attempts.
        base: Duration,

        /// Represents the maximum duration to wait between attempts. None indicates no limit.
        max_duration: Option<Duration>,

        /// Represents the maximum attempts to make before failing. None indicates no limit.
        max_retries: Option<usize>,

        /// Represents the maximum time to wait for a single reconnect attempt. None indicates
        /// no limit.
        timeout: Option<Duration>,
    },

    /// A retry strategy driven by a fixed interval: the wait between attempts never changes.
    FixedInterval {
        /// Represents the time between reconnect attempts.
        interval: Duration,

        /// Represents the maximum attempts to make before failing. None indicates no limit.
        max_retries: Option<usize>,

        /// Represents the maximum time to wait for a single reconnect attempt. None indicates
        /// no limit.
        timeout: Option<Duration>,
    },
}
impl Default for ReconnectStrategy {
    /// Builds the default strategy: exponential backoff that starts at 1 second and doubles
    /// after each failed attempt, waits at most 30 seconds between attempts, gives up after
    /// 10 attempts, and limits each attempt to 5 minutes.
    fn default() -> Self {
        let base = Duration::from_secs(1);
        let max_duration = Some(Duration::from_secs(30));
        let timeout = Some(Duration::from_secs(5 * 60));

        Self::ExponentialBackoff {
            base,
            factor: 2.0,
            max_duration,
            max_retries: Some(10),
            timeout,
        }
    }
}
impl ReconnectStrategy {
    /// Attempts to re-establish the connection of `reconnectable` according to this strategy.
    ///
    /// Each attempt calls [`Reconnectable::reconnect`], bounded by the strategy's per-attempt
    /// timeout when one is configured. After a failed attempt that still leaves attempts
    /// available, the strategy sleeps for the current wait duration and then computes the next
    /// wait duration, capped at the strategy's maximum duration when one is configured.
    ///
    /// # Errors
    ///
    /// * Returns an error of kind [`io::ErrorKind::ConnectionAborted`] without making any
    ///   attempt when the strategy's maximum attempt count is zero.
    /// * Returns the error of the last attempt once all attempts are used up; an attempt that
    ///   exceeds the timeout contributes the timeout error converted into an [`io::Error`].
    pub async fn reconnect<T: Reconnectable>(&mut self, reconnectable: &mut T) -> io::Result<()> {
        // Remaining attempts, or None when the strategy keeps trying forever
        let mut retries_remaining = self.max_retries();

        // With no attempts permitted there is nothing to try; report failure rather than
        // claiming the connection was restored
        if retries_remaining == Some(0) {
            return Err(io::Error::new(
                io::ErrorKind::ConnectionAborted,
                "Reconnect strategy allows no attempts",
            ));
        }

        // Get timeout if strategy will employ one
        let timeout = self.timeout();

        // Get maximum allowed duration between attempts
        let max_duration = self.max_duration();

        // Keep track of last sleep length for use in adjustment
        let mut previous_sleep = None;
        let mut current_sleep = self.initial_sleep_duration();

        loop {
            // Perform reconnect attempt
            let result = match timeout {
                Some(timeout) => {
                    match tokio::time::timeout(timeout, reconnectable.reconnect()).await {
                        Ok(x) => x,
                        Err(x) => Err(x.into()),
                    }
                }
                None => reconnectable.reconnect().await,
            };

            // If reconnect was successful, we're done and we can exit
            let error = match result {
                Ok(()) => return Ok(()),
                Err(x) => x,
            };

            // Consume one attempt if we have a limit; once none remain, fail right away instead
            // of sleeping through a backoff that no further attempt would follow
            if let Some(remaining) = retries_remaining.as_mut() {
                *remaining = remaining.saturating_sub(1);
                if *remaining == 0 {
                    return Err(error);
                }
            }

            // Sleep before making next attempt
            tokio::time::sleep(current_sleep).await;

            // Update our sleep duration
            let next_sleep = self.adjust_sleep(previous_sleep, current_sleep);
            previous_sleep = Some(current_sleep);
            current_sleep = match max_duration {
                Some(duration) => std::cmp::min(next_sleep, duration),
                None => next_sleep,
            };
        }
    }

    /// Returns the maximum duration between reconnect attempts, or None if there is no limit.
    ///
    /// A fixed interval strategy never grows its wait, so it reports no limit.
    pub fn max_duration(&self) -> Option<Duration> {
        match self {
            Self::ExponentialBackoff { max_duration, .. }
            | Self::FibonacciBackoff { max_duration, .. } => *max_duration,
            Self::FixedInterval { .. } => None,
        }
    }

    /// Returns the maximum reconnect attempts the strategy will perform, or None if will attempt
    /// forever.
    pub fn max_retries(&self) -> Option<usize> {
        match self {
            Self::ExponentialBackoff { max_retries, .. }
            | Self::FibonacciBackoff { max_retries, .. }
            | Self::FixedInterval { max_retries, .. } => *max_retries,
        }
    }

    /// Returns the timeout per reconnect attempt that is associated with the strategy.
    pub fn timeout(&self) -> Option<Duration> {
        match self {
            Self::ExponentialBackoff { timeout, .. }
            | Self::FibonacciBackoff { timeout, .. }
            | Self::FixedInterval { timeout, .. } => *timeout,
        }
    }

    /// Returns the initial duration to sleep.
    fn initial_sleep_duration(&self) -> Duration {
        match self {
            Self::ExponentialBackoff { base, .. } | Self::FibonacciBackoff { base, .. } => *base,
            Self::FixedInterval { interval, .. } => *interval,
        }
    }

    /// Adjusts next sleep duration based on the strategy.
    ///
    /// `prev` is the sleep used before `curr` (None on the first adjustment) and `curr` is the
    /// sleep that was just used.
    fn adjust_sleep(&self, prev: Option<Duration>, curr: Duration) -> Duration {
        match self {
            Self::ExponentialBackoff { factor, .. } => {
                // Work in whole milliseconds. The float-to-integer cast saturates, so a product
                // beyond the u64 range becomes u64::MAX and a negative or NaN product becomes 0
                let next_millis = (curr.as_millis() as f64) * factor;
                Duration::from_millis(next_millis as u64)
            }
            Self::FibonacciBackoff { .. } => {
                // Next term is the sum of the two previous terms, starting the series from zero
                prev.unwrap_or_default()
                    .checked_add(curr)
                    .unwrap_or(Duration::MAX)
            }
            Self::FixedInterval { .. } => curr,
        }
    }
}
Loading…
Cancel
Save