From e304e6a6893b75e5529a0e9e1ebacf83892f3bc6 Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Fri, 14 Jul 2023 18:54:22 -0500 Subject: [PATCH] Fix shutting down killed connections from a manager --- CHANGELOG.md | 6 +++ distant-net/src/manager/server.rs | 20 +++++++- distant-net/src/manager/server/connection.rs | 49 ++++++++++++++++++-- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d56929..df07ac2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- When terminating a connection using `distant manager kill`, the connection is + now properly dropped, resulting in servers waiting to terminate due to + `--shutdown lonely=N` now shutting down accordingly + ## [0.20.0-alpha.13] ### Added diff --git a/distant-net/src/manager/server.rs b/distant-net/src/manager/server.rs index bf9c504..9b23121 100644 --- a/distant-net/src/manager/server.rs +++ b/distant-net/src/manager/server.rs @@ -175,7 +175,25 @@ impl ManagerServer { /// Kills the connection to the server with the specified `id` async fn kill(&self, id: ConnectionId) -> io::Result<()> { match self.connections.write().await.remove(&id) { - Some(_) => Ok(()), + Some(connection) => { + // Close any open channels + if let Ok(ids) = connection.channel_ids().await { + let mut channels_lock = self.channels.write().await; + for id in ids { + if let Some(channel) = channels_lock.remove(&id) { + if let Err(x) = channel.close() { + error!("[Conn {id}] {x}"); + } + } + } + } + + // Make sure the connection is aborted so nothing new can happen + debug!("[Conn {id}] Aborting"); + connection.abort(); + + Ok(()) + } None => Err(io::Error::new( io::ErrorKind::NotConnected, "No connection found", diff --git a/distant-net/src/manager/server/connection.rs b/distant-net/src/manager/server/connection.rs index 2aaf921..3e9229a 100644 --- 
a/distant-net/src/manager/server/connection.rs +++ b/distant-net/src/manager/server/connection.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::{fmt, io}; use log::*; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use crate::client::{Mailbox, UntypedClient}; @@ -62,11 +62,17 @@ impl ManagerConnection { pub async fn spawn( spawn: Destination, options: Map, - client: UntypedClient, + mut client: UntypedClient, ) -> io::Result { let connection_id = rand::random(); let (tx, rx) = mpsc::unbounded_channel(); + // NOTE: Ensure that the connection is severed when the client is dropped; otherwise, when + // the connection is terminated via aborting it or the connection being dropped, the + // connection will persist which can cause problems such as lonely shutdown of the server + // never triggering! + client.shutdown_on_drop(true); + let (request_tx, request_rx) = mpsc::unbounded_channel(); let action_task = tokio::spawn(action_task(connection_id, rx, request_tx)); let response_task = tokio::spawn(response_task( @@ -105,16 +111,41 @@ impl ManagerConnection { tx: self.tx.clone(), }) } -} -impl Drop for ManagerConnection { - fn drop(&mut self) { + pub async fn channel_ids(&self) -> io::Result> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Action::GetRegistered { cb: tx }) + .map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel_ids failed: {x}"), + ) + })?; + + let channel_ids = rx.await.map_err(|x| { + io::Error::new( + io::ErrorKind::BrokenPipe, + format!("channel_ids callback dropped: {x}"), + ) + })?; + Ok(channel_ids) + } + + /// Aborts the tasks used to engage with the connection. 
+ pub fn abort(&self) { self.action_task.abort(); self.request_task.abort(); self.response_task.abort(); } } +impl Drop for ManagerConnection { + fn drop(&mut self) { + self.abort(); + } +} + enum Action { Register { id: ManagerChannelId, @@ -125,6 +156,10 @@ enum Action { id: ManagerChannelId, }, + GetRegistered { + cb: oneshot::Sender>, + }, + Read { res: UntypedResponse<'static>, }, @@ -140,6 +175,7 @@ impl fmt::Debug for Action { match self { Self::Register { id, .. } => write!(f, "Action::Register {{ id: {id}, .. }}"), Self::Unregister { id } => write!(f, "Action::Unregister {{ id: {id} }}"), + Self::GetRegistered { .. } => write!(f, "Action::GetRegistered {{ .. }}"), Self::Read { .. } => write!(f, "Action::Read {{ .. }}"), Self::Write { id, .. } => write!(f, "Action::Write {{ id: {id}, .. }}"), } @@ -204,6 +240,9 @@ async fn action_task( Action::Unregister { id } => { registered.remove(&id); } + Action::GetRegistered { cb } => { + let _ = cb.send(registered.keys().copied().collect()); + } Action::Read { mut res } => { // Split {channel id}_{request id} back into pieces and // update the origin id to match the request id only