use super::data::{ ConnectionId, ConnectionInfo, ConnectionList, Destination, ManagerCapabilities, ManagerRequest, ManagerResponse, }; use crate::{ DistantChannel, DistantClient, DistantMsg, DistantRequestData, DistantResponseData, Map, }; use distant_net::{ router, Auth, AuthServer, Client, IntoSplit, MpscTransport, OneshotListener, Request, Response, ServerExt, ServerRef, UntypedTransportRead, UntypedTransportWrite, }; use log::*; use std::{ collections::HashMap, io, ops::{Deref, DerefMut}, }; use tokio::task::JoinHandle; mod config; pub use config::*; mod ext; pub use ext::*; router!(DistantManagerClientRouter { auth_transport: Request => Response, manager_transport: Response => Request, }); /// Represents a client that can connect to a remote distant manager pub struct DistantManagerClient { auth: Box, client: Client, distant_clients: HashMap, } impl Drop for DistantManagerClient { fn drop(&mut self) { self.auth.abort(); self.client.abort(); } } /// Represents a raw channel between a manager client and some remote server pub struct RawDistantChannel { pub transport: MpscTransport< Request>, Response>, >, forward_task: JoinHandle<()>, mailbox_task: JoinHandle<()>, } impl RawDistantChannel { pub fn abort(&self) { self.forward_task.abort(); self.mailbox_task.abort(); } } impl Deref for RawDistantChannel { type Target = MpscTransport< Request>, Response>, >; fn deref(&self) -> &Self::Target { &self.transport } } impl DerefMut for RawDistantChannel { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.transport } } struct ClientHandle { client: DistantClient, forward_task: JoinHandle<()>, mailbox_task: JoinHandle<()>, } impl Drop for ClientHandle { fn drop(&mut self) { self.forward_task.abort(); self.mailbox_task.abort(); } } impl DistantManagerClient { /// Initializes a client using the provided [`UntypedTransport`] pub fn new(config: DistantManagerClientConfig, transport: T) -> io::Result where T: IntoSplit + 'static, T::Read: UntypedTransportRead + 'static, T::Write: UntypedTransportWrite + 'static, { let DistantManagerClientRouter { auth_transport, manager_transport, .. } = DistantManagerClientRouter::new(transport); // Initialize our client with manager request/response transport let (writer, reader) = manager_transport.into_split(); let client = Client::new(writer, reader)?; // Initialize our auth handler with auth/auth transport let auth = AuthServer { on_challenge: config.on_challenge, on_verify: config.on_verify, on_info: config.on_info, on_error: config.on_error, } .start(OneshotListener::from_value(auth_transport.into_split()))?; Ok(Self { auth, client, distant_clients: HashMap::new(), }) } /// Request that the manager launches a new server at the given `destination` /// with `options` being passed for destination-specific details, returning the new /// `destination` of the spawned server to connect to pub async fn launch( &mut self, destination: impl Into, options: impl Into, ) -> io::Result { let destination = Box::new(destination.into()); let options = options.into(); trace!("launch({}, {})", destination, options); let res = self .client .send(ManagerRequest::Launch { destination, options, }) .await?; match res.payload { ManagerResponse::Launched { destination } => Ok(destination), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Request that the manager establishes a new connection at the given `destination` /// with `options` being passed for destination-specific details pub async fn connect( &mut self, destination: impl Into, options: impl Into, ) -> io::Result { let destination = Box::new(destination.into()); let options = options.into(); trace!("connect({}, {})", destination, options); let res = self .client .send(ManagerRequest::Connect { destination, options, }) .await?; match res.payload { ManagerResponse::Connected { id } => Ok(id), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Establishes a channel with the server represented by the `connection_id`, /// returning a [`DistantChannel`] acting as the connection /// /// ### Note /// /// Multiple calls to open a channel against the same connection will result in /// clones of the same [`DistantChannel`] rather than establishing a duplicate /// remote connection to the same server pub async fn open_channel( &mut self, connection_id: ConnectionId, ) -> io::Result { trace!("open_channel({})", connection_id); if let Some(handle) = self.distant_clients.get(&connection_id) { Ok(handle.client.clone_channel()) } else { let RawDistantChannel { transport, forward_task, mailbox_task, } = self.open_raw_channel(connection_id).await?; let (writer, reader) = transport.into_split(); let client = DistantClient::new(writer, reader)?; let channel = client.clone_channel(); self.distant_clients.insert( connection_id, ClientHandle { client, forward_task, mailbox_task, }, ); Ok(channel) } } /// Establishes a channel with the server represented by the `connection_id`, /// returning a [`Transport`] acting as the connection /// /// ### Note /// /// Multiple calls to open a channel against the same connection will result in establishing a /// duplicate remote connections to the same server, so take care when using this method pub async fn open_raw_channel( &mut self, connection_id: ConnectionId, ) -> io::Result { trace!("open_raw_channel({})", connection_id); let mut mailbox = self .client .mail(ManagerRequest::OpenChannel { id: connection_id }) .await?; // Wait for the first response, which should be channel confirmation let channel_id = match mailbox.next().await { Some(response) => match response.payload { ManagerResponse::ChannelOpened { id } => Ok(id), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), }, None => Err(io::Error::new( io::ErrorKind::ConnectionAborted, "open_channel mailbox aborted", )), }?; // Spawn reader and writer tasks to forward requests and replies // using our opened channel let (t1, t2) = MpscTransport::pair(1); let (mut writer, mut reader) = t1.into_split(); let mailbox_task = tokio::spawn(async move { use distant_net::TypedAsyncWrite; while let Some(response) = mailbox.next().await { match response.payload { ManagerResponse::Channel { response, .. } => { if let Err(x) = writer.write(response).await { error!("[Conn {}] {}", connection_id, x); } } ManagerResponse::ChannelClosed { .. } => break, _ => continue, } } }); let mut manager_channel = self.client.clone_channel(); let forward_task = tokio::spawn(async move { use distant_net::TypedAsyncRead; loop { match reader.read().await { Ok(Some(request)) => { // NOTE: In this situation, we do not expect a response to this // request (even if the server sends something back) if let Err(x) = manager_channel .fire(ManagerRequest::Channel { id: channel_id, request, }) .await { error!("[Conn {}] {}", connection_id, x); } } Ok(None) => break, Err(x) => { error!("[Conn {}] {}", connection_id, x); continue; } } } }); Ok(RawDistantChannel { transport: t2, forward_task, mailbox_task, }) } /// Retrieves a list of supported capabilities pub async fn capabilities(&mut self) -> io::Result { trace!("capabilities()"); let res = self.client.send(ManagerRequest::Capabilities).await?; match res.payload { ManagerResponse::Capabilities { supported } => Ok(supported), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Retrieves information about a specific connection pub async fn info(&mut self, id: ConnectionId) -> io::Result { trace!("info({})", id); let res = self.client.send(ManagerRequest::Info { id }).await?; match res.payload { ManagerResponse::Info(info) => Ok(info), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Kills the specified connection pub async fn kill(&mut self, id: ConnectionId) -> io::Result<()> { trace!("kill({})", id); let res = self.client.send(ManagerRequest::Kill { id }).await?; match res.payload { ManagerResponse::Killed => Ok(()), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Retrieves a list of active connections pub async fn list(&mut self) -> io::Result { trace!("list()"); let res = self.client.send(ManagerRequest::List).await?; match res.payload { ManagerResponse::List(list) => Ok(list), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } /// Requests that the manager shuts down pub async fn shutdown(&mut self) -> io::Result<()> { trace!("shutdown()"); let res = self.client.send(ManagerRequest::Shutdown).await?; match res.payload { ManagerResponse::Shutdown => Ok(()), ManagerResponse::Error(x) => Err(x.into()), x => Err(io::Error::new( io::ErrorKind::InvalidData, format!("Got unexpected response: {:?}", x), )), } } } #[cfg(test)] mod tests { use super::*; use crate::data::{Error, ErrorKind}; use distant_net::{ FramedTransport, InmemoryTransport, PlainCodec, UntypedTransportRead, UntypedTransportWrite, }; fn setup() -> ( DistantManagerClient, FramedTransport, ) { let (t1, t2) = FramedTransport::pair(100); let client = DistantManagerClient::new(DistantManagerClientConfig::with_empty_prompts(), t1) .unwrap(); (client, t2) } #[inline] fn test_error() -> Error { Error { kind: ErrorKind::Interrupted, description: "test error".to_string(), } } #[inline] fn test_io_error() -> io::Error { test_error().into() } #[tokio::test] async fn connect_should_report_error_if_receives_error_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Error(test_error()), )) .await .unwrap(); }); let err = client .connect( "scheme://host".parse::().unwrap(), "key=value".parse::().unwrap(), ) .await .unwrap_err(); assert_eq!(err.kind(), test_io_error().kind()); assert_eq!(err.to_string(), test_io_error().to_string()); } #[tokio::test] async fn connect_should_report_error_if_receives_unexpected_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Shutdown)) .await .unwrap(); }); let err = client .connect( "scheme://host".parse::().unwrap(), "key=value".parse::().unwrap(), ) .await .unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); } #[tokio::test] async fn connect_should_return_id_from_successful_response() { let (mut client, mut transport) = setup(); let expected_id = 999; tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Connected { id: expected_id }, )) .await .unwrap(); }); let id = client .connect( "scheme://host".parse::().unwrap(), "key=value".parse::().unwrap(), ) .await .unwrap(); assert_eq!(id, expected_id); } #[tokio::test] async fn info_should_report_error_if_receives_error_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Error(test_error()), )) .await .unwrap(); }); let err = client.info(123).await.unwrap_err(); assert_eq!(err.kind(), test_io_error().kind()); assert_eq!(err.to_string(), test_io_error().to_string()); } #[tokio::test] async fn info_should_report_error_if_receives_unexpected_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Shutdown)) .await .unwrap(); }); let err = client.info(123).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); } #[tokio::test] async fn info_should_return_connection_info_from_successful_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); let info = ConnectionInfo { id: 123, destination: "scheme://host".parse::().unwrap(), options: "key=value".parse::().unwrap(), }; transport .write(Response::new(request.id, ManagerResponse::Info(info))) .await .unwrap(); }); let info = client.info(123).await.unwrap(); assert_eq!(info.id, 123); assert_eq!( info.destination, "scheme://host".parse::().unwrap() ); assert_eq!(info.options, "key=value".parse::().unwrap()); } #[tokio::test] async fn list_should_report_error_if_receives_error_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Error(test_error()), )) .await .unwrap(); }); let err = client.list().await.unwrap_err(); assert_eq!(err.kind(), test_io_error().kind()); assert_eq!(err.to_string(), test_io_error().to_string()); } #[tokio::test] async fn list_should_report_error_if_receives_unexpected_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Shutdown)) .await .unwrap(); }); let err = client.list().await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); } #[tokio::test] async fn list_should_return_connection_list_from_successful_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); let mut list = ConnectionList::new(); list.insert(123, "scheme://host".parse::().unwrap()); transport .write(Response::new(request.id, ManagerResponse::List(list))) .await .unwrap(); }); let list = client.list().await.unwrap(); assert_eq!(list.len(), 1); assert_eq!( list.get(&123).expect("Connection list missing item"), &"scheme://host".parse::().unwrap() ); } #[tokio::test] async fn kill_should_report_error_if_receives_error_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Error(test_error()), )) .await .unwrap(); }); let err = client.kill(123).await.unwrap_err(); assert_eq!(err.kind(), test_io_error().kind()); assert_eq!(err.to_string(), test_io_error().to_string()); } #[tokio::test] async fn kill_should_report_error_if_receives_unexpected_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Shutdown)) .await .unwrap(); }); let err = client.kill(123).await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); } #[tokio::test] async fn kill_should_return_success_from_successful_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Killed)) .await .unwrap(); }); client.kill(123).await.unwrap(); } #[tokio::test] async fn shutdown_should_report_error_if_receives_error_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Connected { id: 0 }, )) .await .unwrap(); }); let err = client.shutdown().await.unwrap_err(); assert_eq!(err.kind(), io::ErrorKind::InvalidData); } #[tokio::test] async fn shutdown_should_report_error_if_receives_unexpected_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new( request.id, ManagerResponse::Error(test_error()), )) .await .unwrap(); }); let err = client.shutdown().await.unwrap_err(); assert_eq!(err.kind(), test_io_error().kind()); assert_eq!(err.to_string(), test_io_error().to_string()); } #[tokio::test] async fn shutdown_should_return_success_from_successful_response() { let (mut client, mut transport) = setup(); tokio::spawn(async move { let request = transport .read::>() .await .unwrap() .unwrap(); transport .write(Response::new(request.id, ManagerResponse::Shutdown)) .await .unwrap(); }); client.shutdown().await.unwrap(); } }