pull/146/head
Chip Senkbeil 2 years ago
parent 05473de685
commit cb6e7e85cd
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -410,10 +410,7 @@ fn read_lsp_messages(input: &[u8]) -> io::Result<(Option<Vec<u8>>, Vec<LspMsg>)>
mod tests {
use super::*;
use crate::data::{DistantRequestData, DistantResponseData};
use distant_net::{
Client, FramedTransport, InmemoryTransport, IntoSplit, PlainCodec, Request, Response,
TypedAsyncRead, TypedAsyncWrite,
};
use distant_net::{Client, FramedTransport, InmemoryTransport, PlainCodec, Request, Response};
use std::{future::Future, time::Duration};
/// Timeout used with timeout function
@ -425,8 +422,7 @@ mod tests {
RemoteLspProcess,
) {
let (mut t1, t2) = FramedTransport::pair(100);
let (writer, reader) = t2.into_split();
let session = Client::new(writer, reader).unwrap();
let session = Client::new(t2);
let spawn_task = tokio::spawn(async move {
RemoteLspCommand::new()
.spawn(session.clone_channel(), String::from("cmd arg"))
@ -520,7 +516,7 @@ mod tests {
tokio::task::yield_now().await;
let result = timeout(
TIMEOUT,
TypedAsyncRead::<Request<DistantRequestData>>::read(&mut transport),
transport.read_frame_as::<Request<DistantRequestData>>(),
)
.await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);

@ -3,9 +3,11 @@ use crate::{
manager::data::{ChannelId, ConnectionId, Destination},
DistantMsg, DistantRequestData, DistantResponseData, ManagerResponse,
};
use distant_net::{Request, Response, ServerReply};
use distant_net::{
FramedTransport, Interest, Request, ServerReply, Transport, UntypedRequest, UntypedResponse,
};
use log::*;
use std::{collections::HashMap, io};
use std::{collections::HashMap, io, time::Duration};
use tokio::{sync::mpsc, task::JoinHandle};
/// Represents a connection a distant manager has with some distant-compatible server
@ -13,15 +15,15 @@ pub struct DistantManagerConnection {
pub id: ConnectionId,
pub destination: Destination,
pub options: Map,
tx: mpsc::Sender<StateMachine>,
reader_task: JoinHandle<()>,
writer_task: JoinHandle<()>,
tx: mpsc::Sender<Action>,
transport_task: JoinHandle<()>,
action_task: JoinHandle<()>,
}
#[derive(Clone)]
pub struct DistantManagerChannel {
channel_id: ChannelId,
tx: mpsc::Sender<StateMachine>,
tx: mpsc::Sender<Action>,
}
impl DistantManagerChannel {
@ -32,9 +34,9 @@ impl DistantManagerChannel {
pub async fn send(&self, request: Request<DistantMsg<DistantRequestData>>) -> io::Result<()> {
let channel_id = self.channel_id;
self.tx
.send(StateMachine::Write {
.send(Action::Write {
id: channel_id,
request,
data: request.to_vec()?,
})
.await
.map_err(|x| {
@ -48,7 +50,7 @@ impl DistantManagerChannel {
pub async fn close(&self) -> io::Result<()> {
let channel_id = self.channel_id;
self.tx
.send(StateMachine::Unregister { id: channel_id })
.send(Action::Unregister { id: channel_id })
.await
.map_err(|x| {
io::Error::new(
@ -59,111 +61,22 @@ impl DistantManagerChannel {
}
}
enum StateMachine {
Register {
id: ChannelId,
reply: ServerReply<ManagerResponse>,
},
Unregister {
id: ChannelId,
},
Read {
response: Response<DistantMsg<DistantResponseData>>,
},
Write {
id: ChannelId,
request: Request<DistantMsg<DistantRequestData>>,
},
}
impl DistantManagerConnection {
pub fn new(
pub fn new<T: Transport>(
destination: Destination,
options: Map,
mut writer: BoxedDistantWriter,
mut reader: BoxedDistantReader,
transport: FramedTransport<T>,
) -> Self {
let connection_id = rand::random();
let (tx, mut rx) = mpsc::channel(1);
let reader_task = {
let tx = tx.clone();
tokio::spawn(async move {
loop {
match reader.read().await {
Ok(Some(response)) => {
if tx.send(StateMachine::Read { response }).await.is_err() {
break;
}
}
Ok(None) => break,
Err(x) => {
error!("[Conn {connection_id}] {x}");
continue;
}
}
}
})
};
let writer_task = tokio::spawn(async move {
let mut registered = HashMap::new();
while let Some(state_machine) = rx.recv().await {
match state_machine {
StateMachine::Register { id, reply } => {
registered.insert(id, reply);
}
StateMachine::Unregister { id } => {
registered.remove(&id);
}
StateMachine::Read { mut response } => {
// Split {channel id}_{request id} back into pieces and
// update the origin id to match the request id only
let channel_id = match response.origin_id.split_once('_') {
Some((cid_str, oid_str)) => {
if let Ok(cid) = cid_str.parse::<ChannelId>() {
response.origin_id = oid_str.to_string();
cid
} else {
continue;
}
}
None => continue,
};
if let Some(reply) = registered.get(&channel_id) {
let response = ManagerResponse::Channel {
id: channel_id,
response,
};
if let Err(x) = reply.send(response).await {
error!("[Conn {connection_id}] {x}");
}
}
}
StateMachine::Write { id, request } => {
// Combine channel id with request id so we can properly forward
// the response containing this in the origin id
let request = Request {
id: format!("{id}_{}", request.id),
payload: request.payload,
};
if let Err(x) = writer.write(request).await {
error!("[Conn {connection_id}] {x}");
}
}
}
}
});
Self {
id: connection_id,
destination,
options,
tx,
reader_task,
writer_task,
transport_task,
action_task,
}
}
@ -173,7 +86,7 @@ impl DistantManagerConnection {
) -> io::Result<DistantManagerChannel> {
let channel_id = rand::random();
self.tx
.send(StateMachine::Register {
.send(Action::Register {
id: channel_id,
reply,
})
@ -193,7 +106,189 @@ impl DistantManagerConnection {
impl Drop for DistantManagerConnection {
fn drop(&mut self) {
self.reader_task.abort();
self.writer_task.abort();
self.transport_task.abort();
self.action_task.abort();
}
}
enum Action {
Register {
id: ChannelId,
reply: ServerReply<ManagerResponse>,
},
Unregister {
id: ChannelId,
},
Read {
data: Vec<u8>,
},
Write {
id: ChannelId,
data: Vec<u8>,
},
}
/// Internal task to read and write from a [`Transport`].
///
/// * `id` - the id of the connection.
/// * `transport` - the fully-authenticated transport.
/// * `rx` - used to receive outgoing data to send through the connection.
/// * `tx` - used to send new [`Action`]s to process.
async fn transport_task<T>(
id: ConnectionId,
transport: FramedTransport<T>,
mut rx: mpsc::UnboundedReceiver<Vec<u8>>,
mut tx: mpsc::UnboundedSender<Action>,
sleep_duration: Duration,
) {
loop {
let ready = match transport
.ready(Interest::READABLE | Interest::WRITABLE)
.await
{
Ok(ready) => ready,
Err(x) => {
error!("[Conn {id}] Querying ready status failed: {x}");
break;
}
};
// Keep track of whether we read or wrote anything
let mut read_blocked = !ready.is_readable();
let mut write_blocked = !ready.is_writable();
// If transport is readable, attempt to read a frame and forward it to our action task
if ready.is_readable() {
match transport.try_read_frame() {
Ok(Some(frame)) => {
if let Err(x) = tx
.send(Action::Read {
data: frame.into_item().into_owned(),
})
.await
{
error!("[Conn {id}] Failed to forward frame: {x}");
}
}
Ok(None) => {
debug!("[Conn {id}] Connection closed");
break;
}
Err(x) if x.kind() == io::ErrorKind::WouldBlock => read_blocked = true,
Err(x) => {
error!("[Conn {id}] {x}");
}
}
}
// If transport is writable, check if we have something to write
if ready.is_writable() {
if let Ok(data) = rx.try_recv() {
match transport.try_write_frame(data) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => error!("[Conn {id}] Send failed: {x}"),
}
} else {
// In the case of flushing, there are two scenarios in which we want to
// mark no write occurring:
//
// 1. When flush did not write any bytes, which can happen when the buffer
// is empty
// 2. When the call to write bytes blocks
match transport.try_flush() {
Ok(0) => write_blocked = true,
Ok(_) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => write_blocked = true,
Err(x) => {
error!("[Conn {id}] {x}");
}
}
}
}
// If we did not read or write anything, sleep a bit to offload CPU usage
if read_blocked && write_blocked {
tokio::time::sleep(sleep_duration).await;
}
}
}
/// Internal task to process [`Action`] items.
///
/// * `id` - the id of the connection.
/// * `rx` - used to receive new [`Action`]s to process.
/// * `tx` - used to send outgoing data through the connection.
async fn action_task(
id: ConnectionId,
mut rx: mpsc::UnboundedReceiver<Action>,
mut tx: mpsc::UnboundedSender<Vec<u8>>,
) {
let mut registered = HashMap::new();
while let Some(action) = rx.recv().await {
match action {
Action::Register { id, reply } => {
registered.insert(id, reply);
}
Action::Unregister { id } => {
registered.remove(&id);
}
Action::Read { data } => {
// Partially parse data into a request so we can modify the origin id
let mut response = match UntypedResponse::from_slice(&data) {
Ok(response) => response,
Err(x) => {
error!("[Conn {id}] Failed to parse response during read: {x}");
continue;
}
};
// Split {channel id}_{request id} back into pieces and
// update the origin id to match the request id only
let channel_id = match response.origin_id.split_once('_') {
Some((cid_str, oid_str)) => {
if let Ok(cid) = cid_str.parse::<ChannelId>() {
response.set_origin_id(oid_str);
cid
} else {
continue;
}
}
None => continue,
};
if let Some(reply) = registered.get(&channel_id) {
let response = ManagerResponse::Channel {
id: channel_id,
data: response.to_bytes(),
};
if let Err(x) = reply.send(response).await {
error!("[Conn {id}] {x}");
}
}
}
Action::Write { id, data } => {
// Partially parse data into a request so we can modify the id
let mut request = match UntypedRequest::from_slice(&data) {
Ok(request) => request,
Err(x) => {
error!("[Conn {id}] Failed to parse request during write: {x}");
continue;
}
};
// Combine channel id with request id so we can properly forward
// the response containing this in the origin id
request.set_id(format!("{id}_{}", request.id));
if let Err(x) = tx.send(request.to_bytes()).await {
error!("[Conn {id}] {x}");
}
}
}
}
}

@ -1,9 +1,6 @@
use crate::{
data::Map, manager::data::Destination, DistantClient, DistantMsg, DistantRequestData,
DistantResponseData,
};
use crate::{data::Map, manager::data::Destination};
use async_trait::async_trait;
use distant_net::{auth::Authenticator, Request, Response};
use distant_net::{auth::Authenticator, FramedTransport, Transport};
use std::{future::Future, io};
pub type BoxedLaunchHandler = Box<dyn LaunchHandler>;
@ -15,8 +12,7 @@ pub type BoxedConnectHandler = Box<dyn ConnectHandler>;
/// * `destination` is the location where the server will be started.
/// * `options` is provided to include extra information needed to launch or establish the
/// connection.
/// * `authenticator` is provided to support a challenge-based authentication while launching or
/// connecting.
/// * `authenticator` is provided to support a challenge-based authentication while launching.
///
/// Returns a [`Destination`] representing the new origin to use if a connection is desired.
#[async_trait]
@ -59,21 +55,21 @@ pub trait ConnectHandler: Send + Sync {
destination: &Destination,
options: &Map,
authenticator: &mut dyn Authenticator,
) -> io::Result<BoxedDistantWriterReader>;
) -> io::Result<FramedTransport<Box<dyn Transport>>>;
}
#[async_trait]
impl<F, R> ConnectHandler for F
where
F: for<'a> Fn(&'a Destination, &'a Map, &'a mut dyn Authenticator) -> R + Send + Sync + 'static,
R: Future<Output = io::Result<BoxedDistantWriterReader>> + Send + 'static,
R: Future<Output = io::Result<FramedTransport<Box<dyn Transport>>>> + Send + 'static,
{
async fn connect(
&self,
destination: &Destination,
options: &Map,
authenticator: &mut dyn Authenticator,
) -> io::Result<BoxedDistantWriterReader> {
) -> io::Result<FramedTransport<Box<dyn Transport>>> {
self(destination, options, authenticator).await
}
}

@ -120,6 +120,11 @@ impl<'a> UntypedRequest<'a> {
}
}
/// Updates the id of the request to the given `id`.
pub fn set_id(&mut self, id: impl Into<String>) {
self.id = Cow::Owned(id.into());
}
/// Allocates a new collection of bytes representing the request.
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = vec![0x82];

@ -133,6 +133,16 @@ impl<'a> UntypedResponse<'a> {
}
}
/// Updates the id of the response to the given `id`.
pub fn set_id(&mut self, id: impl Into<String>) {
self.id = Cow::Owned(id.into());
}
/// Updates the origin id of the response to the given `origin_id`.
pub fn set_origin_id(&mut self, origin_id: impl Into<String>) {
self.origin_id = Cow::Owned(origin_id.into());
}
/// Allocates a new collection of bytes representing the response.
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = vec![0x83];

Loading…
Cancel
Save