Refactor listener (#38)

* Finish implementing new listener logic

* Refactor cli tests to work with new format

* Implement tests for remote process

* Fix bugs in LSP stdout, stderr, and stdin

* Add tests for LSP remote process

* Update metadata request & response to support resolving the file type of symlinks
pull/39/head
Chip Senkbeil 3 years ago committed by GitHub
parent 22829d9cc8
commit 9bd2112344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -101,13 +101,18 @@ impl LspData {
&mut self.content
}
/// Updates the header content length based on the current content
pub fn refresh_content_length(&mut self) {
self.header.content_length = self.content.to_string().len();
}
/// Creates a session's info by inspecting the content for session parameters, removing the
/// session parameters from the content. Will also adjust the content length header to match
/// the new size of the content.
pub fn take_session_info(&mut self) -> Result<SessionInfo, LspSessionInfoError> {
match self.content.take_session_info() {
Ok(session) => {
self.header.content_length = self.content.to_string().len();
self.refresh_content_length();
Ok(session)
}
Err(x) => Err(x),

@ -1,15 +1,18 @@
use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
use crate::{client::Session, net::DataStream};
use futures::stream::{Stream, StreamExt};
use std::{
fmt::Write,
io::{self, Cursor, Read},
ops::{Deref, DerefMut},
};
use tokio::{sync::mpsc, task::JoinHandle};
mod data;
pub use data::*;
/// Represents an LSP server process on a remote machine
#[derive(Debug)]
pub struct RemoteLspProcess {
inner: RemoteProcess,
pub stdin: Option<RemoteLspStdin>,
@ -63,6 +66,7 @@ impl DerefMut for RemoteLspProcess {
}
/// A handle to a remote LSP process' standard input (stdin)
#[derive(Debug)]
pub struct RemoteLspStdin {
inner: RemoteStdin,
buf: Option<String>,
@ -75,8 +79,6 @@ impl RemoteLspStdin {
/// Writes data to the stdin of a specific remote process
pub async fn write(&mut self, data: &str) -> io::Result<()> {
let mut queue = Vec::new();
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(data),
@ -85,22 +87,14 @@ impl RemoteLspStdin {
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
}
let (remainder, queue) = read_lsp_messages(buf)?;
self.buf = remainder;
// Process and then send out each LSP message in our queue
for mut data in queue {
// Convert distant:// to file://
data.mut_content().convert_distant_scheme_to_local();
data.refresh_content_length();
self.inner.write(&data.to_string()).await?;
}
@ -109,183 +103,819 @@ impl RemoteLspStdin {
}
/// A handle to a remote LSP process' standard output (stdout)
#[derive(Debug)]
pub struct RemoteLspStdout {
inner: RemoteStdout,
buf: Option<String>,
read_task: JoinHandle<()>,
rx: mpsc::Receiver<io::Result<String>>,
}
impl RemoteLspStdout {
pub fn new(inner: RemoteStdout) -> Self {
Self { inner, buf: None }
let (read_task, rx) = spawn_read_task(Box::pin(futures::stream::unfold(
inner,
|mut inner| async move {
match inner.read().await {
Ok(res) => Some((res, inner)),
Err(_) => None,
}
},
)));
Self { read_task, rx }
}
pub async fn read(&mut self) -> io::Result<String> {
let mut queue = Vec::new();
let data = self.inner.read().await?;
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(&data),
None => self.buf = Some(data),
}
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
}
// Process and then add each LSP message as output
let mut out = String::new();
for mut data in queue {
// Convert file:// to distant://
data.mut_content().convert_local_scheme_to_distant();
write!(&mut out, "{}", data).unwrap();
}
self.rx
.recv()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
}
}
Ok(out)
// Tear down the background reader when the stdout handle is dropped: abort
// the spawned read task so it stops polling the inner stream, and close the
// receiver so any in-flight send on the channel fails instead of blocking.
impl Drop for RemoteLspStdout {
    fn drop(&mut self) {
        self.read_task.abort();
        self.rx.close();
    }
}
/// A handle to a remote LSP process' stderr
#[derive(Debug)]
pub struct RemoteLspStderr {
inner: RemoteStderr,
buf: Option<String>,
read_task: JoinHandle<()>,
rx: mpsc::Receiver<io::Result<String>>,
}
impl RemoteLspStderr {
pub fn new(inner: RemoteStderr) -> Self {
Self { inner, buf: None }
let (read_task, rx) = spawn_read_task(Box::pin(futures::stream::unfold(
inner,
|mut inner| async move {
match inner.read().await {
Ok(res) => Some((res, inner)),
Err(_) => None,
}
},
)));
Self { read_task, rx }
}
pub async fn read(&mut self) -> io::Result<String> {
let mut queue = Vec::new();
let data = self.inner.read().await?;
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(&data),
None => self.buf = Some(data),
}
self.rx
.recv()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))?
}
}
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Tear down the background reader when the stderr handle is dropped: abort
// the spawned read task so it stops polling the inner stream, and close the
// receiver so any in-flight send on the channel fails instead of blocking.
impl Drop for RemoteLspStderr {
    fn drop(&mut self) {
        self.read_task.abort();
        self.rx.close();
    }
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
fn spawn_read_task<S>(mut stream: S) -> (JoinHandle<()>, mpsc::Receiver<io::Result<String>>)
where
S: Stream<Item = String> + Send + Unpin + 'static,
{
let (tx, rx) = mpsc::channel::<io::Result<String>>(1);
let read_task = tokio::spawn(async move {
let mut task_buf: Option<String> = None;
loop {
let data = match stream.next().await {
Some(data) => data,
None => break,
};
// Create or insert into our buffer
match &mut task_buf {
Some(buf) => buf.push_str(&data),
None => task_buf = Some(data),
}
// Read LSP messages from our internal buffer
let buf = task_buf.take().unwrap();
let (remainder, queue) = match read_lsp_messages(buf) {
Ok(x) => x,
Err(x) => {
let _ = tx.send(Err(x)).await;
break;
}
};
task_buf = remainder;
// Process and then add each LSP message as output
if !queue.is_empty() {
let mut out = String::new();
for mut data in queue {
// Convert file:// to distant://
data.mut_content().convert_local_scheme_to_distant();
data.refresh_content_length();
write!(&mut out, "{}", data).unwrap();
}
if tx.send(Ok(out)).await.is_err() {
break;
}
}
}
});
// Process and then add each LSP message as output
let mut out = String::new();
for mut data in queue {
// Convert file:// to distant://
data.mut_content().convert_local_scheme_to_distant();
write!(&mut out, "{}", data).unwrap();
}
(read_task, rx)
}
Ok(out)
/// Parses as many complete LSP messages as possible out of `input`.
///
/// Returns the unparsed trailing text (if any) so the caller can buffer it
/// until more data arrives, together with the messages parsed so far. After
/// the last successful parse the cursor is rewound, because a failed parse
/// attempt may have consumed the start of an incomplete message.
fn read_lsp_messages(input: String) -> io::Result<(Option<String>, Vec<LspData>)> {
    let mut messages = Vec::new();
    let mut cursor = Cursor::new(input);
    let mut last_good = 0;

    loop {
        match LspData::from_buf_reader(&mut cursor) {
            Ok(data) => {
                messages.push(data);
                last_good = cursor.position();
            }
            Err(_) => break,
        }
    }

    // Rewind to just past the last complete message, then capture whatever
    // bytes remain as the leftover to re-buffer
    cursor.set_position(last_good);
    let leftover = if (last_good as usize) >= cursor.get_ref().len() {
        None
    } else {
        let mut rest = String::new();
        cursor.read_to_string(&mut rest)?;
        Some(rest)
    };

    Ok((leftover, messages))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
data::{Request, RequestData, Response, ResponseData},
net::{InmemoryStream, Transport},
};
use std::{future::Future, time::Duration};
/// Timeout used with timeout function
const TIMEOUT: Duration = Duration::from_millis(50);
// Configures an LSP process with a means to send & receive data from outside.
// Returns the test-side transport (playing the role of the remote server)
// plus the spawned RemoteLspProcess once the start handshake has completed.
async fn spawn_lsp_process() -> (Transport<InmemoryStream>, RemoteLspProcess) {
    let (mut t1, t2) = Transport::make_pair();
    let session = Session::initialize(t2).unwrap();
    // Spawn in a separate task because we must concurrently act as the server
    // below — spawn() cannot resolve until its request has been answered
    let spawn_task = tokio::spawn(RemoteLspProcess::spawn(
        String::from("test-tenant"),
        session,
        String::from("cmd"),
        vec![String::from("arg")],
    ));
    // Wait until we get the request from the session
    let req = t1.receive::<Request>().await.unwrap().unwrap();
    // Send back a response through the session, acknowledging the process
    // as started (the id just needs to be some unique value)
    t1.send(Response::new(
        "test-tenant",
        Some(req.id),
        vec![ResponseData::ProcStart { id: rand::random() }],
    ))
    .await
    .unwrap();
    // Wait for the process to be ready
    let proc = spawn_task.await.unwrap().unwrap();
    (t1, proc)
}
#[test]
fn stdin_write_should_only_send_out_complete_lsp_messages() {
todo!();
/// Builds a complete LSP wire message from any serializable value: the
/// `Content-Length` header (byte length of the body) followed by the body
/// rendered as pretty-printed JSON.
fn make_lsp_msg<T>(value: T) -> String
where
    T: serde::Serialize,
{
    let body = serde_json::to_string_pretty(&value).unwrap();
    let header = format!("Content-Length: {}\r\n\r\n", body.len());
    header + &body
}
#[test]
fn stdin_write_should_support_buffering_output_until_a_complete_lsp_message_is_composed() {
// TODO: This tests that we can send part of a message and then the rest later to
// verify that this doesn't block async tasks from continuing
todo!();
/// Awaits `f`, failing with `io::ErrorKind::TimedOut` if it does not
/// resolve within `duration`.
///
/// NOTE: `tokio::select!` polls its branches in random order by default,
/// so listing the sleep branch first does not bias the outcome.
async fn timeout<F, R>(duration: Duration, f: F) -> io::Result<R>
where
    F: Future<Output = R>,
{
    tokio::select! {
        _ = tokio::time::sleep(duration) => Err(io::Error::from(io::ErrorKind::TimedOut)),
        res = f => Ok(res),
    }
}
#[test]
fn stdin_write_should_only_consume_a_complete_lsp_message_even_if_more_is_written() {
todo!();
// Verifies that a single fully-formed LSP message written to stdin is
// forwarded to the remote process as exactly one complete ProcStdin request.
#[tokio::test]
async fn stdin_write_should_only_send_out_complete_lsp_messages() {
    let (mut transport, mut proc) = spawn_lsp_process().await;
    proc.stdin
        .as_mut()
        .unwrap()
        .write(&make_lsp_msg(serde_json::json!({
            "field1": "a",
            "field2": "b",
        })))
        .await
        .unwrap();
    // Validate that the outgoing req is a complete LSP message
    let req = transport.receive::<Request>().await.unwrap().unwrap();
    assert_eq!(req.payload.len(), 1, "Unexpected payload size");
    match &req.payload[0] {
        RequestData::ProcStdin { data, .. } => {
            // The forwarded data must be byte-identical to the full message
            assert_eq!(
                data,
                &make_lsp_msg(serde_json::json!({
                    "field1": "a",
                    "field2": "b",
                }))
            );
        }
        x => panic!("Unexpected request: {:?}", x),
    }
}
#[tokio::test]
async fn stdin_write_should_support_buffering_output_until_a_complete_lsp_message_is_composed()
{
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let (msg_a, msg_b) = msg.split_at(msg.len() / 2);
// Write part of the message that isn't finished
proc.stdin.as_mut().unwrap().write(msg_a).await.unwrap();
// Verify that nothing has been sent out yet
// NOTE: Yield to ensure that data would be waiting at the transport if it was sent
tokio::task::yield_now().await;
let result = timeout(TIMEOUT, transport.receive::<Request>()).await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);
// Write remainder of message
proc.stdin.as_mut().unwrap().write(msg_b).await.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
x => panic!("Unexpected request: {:?}", x),
}
}
#[tokio::test]
async fn stdin_write_should_only_consume_a_complete_lsp_message_even_if_more_is_written() {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let extra = "Content-Length: 123";
// Write a full message plus some extra
proc.stdin
.as_mut()
.unwrap()
.write(&format!("{}{}", msg, extra))
.await
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
x => panic!("Unexpected request: {:?}", x),
}
// Also validate that the internal buffer still contains the extra
assert_eq!(
proc.stdin.unwrap().buf.unwrap(),
extra,
"Extra was not retained"
);
}
#[test]
fn stdin_write_should_support_sending_out_multiple_lsp_messages_if_all_received_at_once() {
todo!();
#[tokio::test]
async fn stdin_write_should_support_sending_out_multiple_lsp_messages_if_all_received_at_once()
{
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg_1 = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let msg_2 = make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}));
// Write two full messages at once
proc.stdin
.as_mut()
.unwrap()
.write(&format!("{}{}", msg_1, msg_2))
.await
.unwrap();
// Validate that the first outgoing req is a complete LSP message matching first
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
x => panic!("Unexpected request: {:?}", x),
}
// Validate that the second outgoing req is a complete LSP message matching second
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}))
);
}
x => panic!("Unexpected request: {:?}", x),
}
}
#[test]
fn stdin_write_should_convert_content_with_distant_scheme_to_file_scheme() {
todo!();
#[tokio::test]
async fn stdin_write_should_convert_content_with_distant_scheme_to_file_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
proc.stdin
.as_mut()
.unwrap()
.write(&make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "file://other/path",
})))
.await
.unwrap();
// Validate that the outgoing req is a complete LSP message
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(req.payload.len(), 1, "Unexpected payload size");
match &req.payload[0] {
RequestData::ProcStdin { data, .. } => {
// Verify the contents AND headers are as expected; in this case,
// this will also ensure that the Content-Length is adjusted
// when the distant scheme was changed to file
assert_eq!(
data,
&make_lsp_msg(serde_json::json!({
"field1": "file://some/path",
"field2": "file://other/path",
}))
);
}
x => panic!("Unexpected request: {:?}", x),
}
}
#[test]
fn stdout_read_should_yield_lsp_messages_as_strings() {
todo!();
#[tokio::test]
async fn stdout_read_should_yield_lsp_messages_as_strings() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
}],
))
.await
.unwrap();
// Receive complete message as stdout from process
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(
out,
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
#[test]
fn stdout_read_should_only_yield_complete_lsp_messages() {
// TODO: This tests that we can get an incomplete message from an inner read
// and then get the rest of the message (maybe in parts) from a later read
// to verify that this doesn't block async tasks from continuing
todo!();
#[tokio::test]
async fn stdout_read_should_only_yield_complete_lsp_messages() {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let (msg_a, msg_b) = msg.split_at(msg.len() / 2);
// Send half of LSP message over stdout
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: msg_a.to_string(),
}],
))
.await
.unwrap();
// Verify that remote process has not received a complete message yet
// NOTE: Yield to ensure that data would be waiting at the transport if it was sent
tokio::task::yield_now().await;
let result = timeout(TIMEOUT, proc.stdout.as_mut().unwrap().read()).await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);
// Send other half of LSP message over stdout
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: msg_b.to_string(),
}],
))
.await
.unwrap();
// Receive complete message as stdout from process
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(
out,
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
#[test]
fn stdout_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available() {
todo!();
#[tokio::test]
async fn stdout_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available(
) {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let extra = "some extra content";
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: format!("{}{}", msg, extra),
}],
))
.await
.unwrap();
// Receive complete message as stdout from process
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(
out,
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
// Verify nothing else was sent
let result = timeout(TIMEOUT, proc.stdout.as_mut().unwrap().read()).await;
assert!(
result.is_err(),
"Unexpected extra content received on stdout"
);
}
#[test]
fn stdout_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
todo!();
#[tokio::test]
async fn stdout_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg_1 = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let msg_2 = make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}));
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: format!("{}{}", msg_1, msg_2),
}],
))
.await
.unwrap();
// Should send both messages back together as a single string
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(
out,
format!(
"{}{}",
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}))
)
);
}
#[test]
fn stdout_read_should_convert_content_with_file_scheme_to_distant_scheme() {
todo!();
#[tokio::test]
async fn stdout_read_should_convert_content_with_file_scheme_to_distant_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stdout to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "file://other/path",
})),
}],
))
.await
.unwrap();
// Receive complete message as stdout from process
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(
out,
make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "distant://other/path",
}))
);
}
#[test]
fn stderr_read_should_yield_lsp_messages_as_strings() {
todo!();
#[tokio::test]
async fn stderr_read_should_yield_lsp_messages_as_strings() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
}],
))
.await
.unwrap();
// Receive complete message as stderr from process
let err = proc.stderr.as_mut().unwrap().read().await.unwrap();
assert_eq!(
err,
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
#[test]
fn stderr_read_should_only_yield_complete_lsp_messages() {
// TODO: This tests that we can get an incomplete message from an inner read
// and then get the rest of the message (maybe in parts) from a later read
// to verify that this doesn't block async tasks from continuing
todo!();
#[tokio::test]
async fn stderr_read_should_only_yield_complete_lsp_messages() {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let (msg_a, msg_b) = msg.split_at(msg.len() / 2);
// Send half of LSP message over stderr
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: msg_a.to_string(),
}],
))
.await
.unwrap();
// Verify that remote process has not received a complete message yet
// NOTE: Yield to ensure that data would be waiting at the transport if it was sent
tokio::task::yield_now().await;
let result = timeout(TIMEOUT, proc.stderr.as_mut().unwrap().read()).await;
assert!(result.is_err(), "Unexpectedly got data: {:?}", result);
// Send other half of LSP message over stderr
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: msg_b.to_string(),
}],
))
.await
.unwrap();
// Receive complete message as stderr from process
let err = proc.stderr.as_mut().unwrap().read().await.unwrap();
assert_eq!(
err,
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}))
);
}
#[test]
fn stderr_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available() {
todo!();
// Verifies that a read on stderr yields exactly one complete LSP message even
// when extra trailing content arrived in the same chunk; the extra stays
// buffered and is not emitted.
//
// NOTE: renamed from "..._more_errput_is_available" — "errput" was an artifact
// of a mechanical out->err rename; no callers depend on a test fn name.
#[tokio::test]
async fn stderr_read_should_only_consume_a_complete_lsp_message_even_if_more_output_is_available(
) {
    let (mut transport, mut proc) = spawn_lsp_process().await;
    let msg = make_lsp_msg(serde_json::json!({
        "field1": "a",
        "field2": "b",
    }));
    let extra = "some extra content";
    // Send complete LSP message as stderr to process
    transport
        .send(Response::new(
            "test-tenant",
            None,
            vec![ResponseData::ProcStderr {
                id: proc.id(),
                data: format!("{}{}", msg, extra),
            }],
        ))
        .await
        .unwrap();
    // Receive complete message as stderr from process
    let err = proc.stderr.as_mut().unwrap().read().await.unwrap();
    assert_eq!(
        err,
        make_lsp_msg(serde_json::json!({
            "field1": "a",
            "field2": "b",
        }))
    );
    // Verify nothing else was sent (the extra is retained, not forwarded)
    let result = timeout(TIMEOUT, proc.stderr.as_mut().unwrap().read()).await;
    assert!(
        result.is_err(),
        "Unexpected extra content received on stderr"
    );
}
#[test]
fn stderr_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
todo!();
#[tokio::test]
async fn stderr_read_should_support_yielding_multiple_lsp_messages_if_all_received_at_once() {
let (mut transport, mut proc) = spawn_lsp_process().await;
let msg_1 = make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
}));
let msg_2 = make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}));
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: format!("{}{}", msg_1, msg_2),
}],
))
.await
.unwrap();
// Should send both messages back together as a single string
let err = proc.stderr.as_mut().unwrap().read().await.unwrap();
assert_eq!(
err,
format!(
"{}{}",
make_lsp_msg(serde_json::json!({
"field1": "a",
"field2": "b",
})),
make_lsp_msg(serde_json::json!({
"field1": "c",
"field2": "d",
}))
)
);
}
#[test]
fn stderr_read_should_convert_content_with_file_scheme_to_distant_scheme() {
todo!();
#[tokio::test]
async fn stderr_read_should_convert_content_with_file_scheme_to_distant_scheme() {
let (mut transport, mut proc) = spawn_lsp_process().await;
// Send complete LSP message as stderr to process
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id: proc.id(),
data: make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "file://other/path",
})),
}],
))
.await
.unwrap();
// Receive complete message as stderr from process
let err = proc.stderr.as_mut().unwrap().read().await.unwrap();
assert_eq!(
err,
make_lsp_msg(serde_json::json!({
"field1": "distant://some/path",
"field2": "distant://other/path",
}))
);
}
}

@ -32,6 +32,7 @@ pub enum RemoteProcessError {
}
/// Represents a process on a remote machine
#[derive(Debug)]
pub struct RemoteProcess {
/// Id of the process
id: usize,
@ -151,6 +152,7 @@ impl RemoteProcess {
}
/// A handle to a remote process' standard input (stdin)
#[derive(Debug)]
pub struct RemoteStdin(mpsc::Sender<String>);
impl RemoteStdin {
@ -164,6 +166,7 @@ impl RemoteStdin {
}
/// A handle to a remote process' standard output (stdout)
#[derive(Debug)]
pub struct RemoteStdout(mpsc::Receiver<String>);
impl RemoteStdout {
@ -177,6 +180,7 @@ impl RemoteStdout {
}
/// A handle to a remote process' stderr
#[derive(Debug)]
pub struct RemoteStderr(mpsc::Receiver<String>);
impl RemoteStderr {
@ -242,8 +246,6 @@ async fn process_incoming_responses(
stderr_tx: mpsc::Sender<String>,
kill_tx: mpsc::Sender<()>,
) -> Result<(bool, Option<i32>), RemoteProcessError> {
let mut result = Err(RemoteProcessError::UnexpectedEof);
while let Some(res) = broadcast.recv().await {
// Check if any of the payload data is the termination
let exit_status = res.payload.iter().find_map(|data| match data {
@ -258,16 +260,10 @@ async fn process_incoming_responses(
for data in res.payload {
match data {
ResponseData::ProcStdout { id, data } if id == proc_id => {
if let Err(_) = stdout_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
let _ = stdout_tx.send(data).await;
}
ResponseData::ProcStderr { id, data } if id == proc_id => {
if let Err(_) = stderr_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
let _ = stderr_tx.send(data).await;
}
_ => {}
}
@ -275,85 +271,468 @@ async fn process_incoming_responses(
// If we got a termination, then exit accordingly
if let Some((success, code)) = exit_status {
result = Ok((success, code));
// Flag that the other task should conclude
let _ = kill_tx.try_send(());
break;
return Ok((success, code));
}
}
// Flag that the other task should conclude
let _ = kill_tx.try_send(());
trace!("Process incoming channel closed");
result
Err(RemoteProcessError::UnexpectedEof)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
data::{Error, ErrorKind},
net::{InmemoryStream, Transport},
};
#[test]
fn spawn_should_return_bad_response_if_payload_size_unexpected() {
todo!();
// Creates a paired in-memory transport and a session driven by its other
// half, so tests can act as the "server" side of the connection via t1.
fn make_session() -> (Transport<InmemoryStream>, Session<InmemoryStream>) {
    let (t1, t2) = Transport::make_pair();
    (t1, Session::initialize(t2).unwrap())
}
#[test]
fn spawn_should_return_bad_response_if_did_not_get_a_indicator_that_process_started() {
todo!();
// Verifies that spawn fails with BadResponse when the reply to the spawn
// request carries an unexpected number of payload entries (here: zero).
#[tokio::test]
async fn spawn_should_return_bad_response_if_payload_size_unexpected() {
    let (mut transport, session) = make_session();
    // Create a task for process spawning as we need to handle the request and a response
    // in a separate async block
    let spawn_task = tokio::spawn(RemoteProcess::spawn(
        String::from("test-tenant"),
        session,
        String::from("cmd"),
        vec![String::from("arg")],
    ));
    // Wait until we get the request from the session
    let req = transport.receive::<Request>().await.unwrap().unwrap();
    // Send back a response through the session
    // (empty payload — not a valid ProcStart indicator)
    transport
        .send(Response::new("test-tenant", Some(req.id), Vec::new()))
        .await
        .unwrap();
    // Get the spawn result and verify
    let result = spawn_task.await.unwrap();
    assert!(
        matches!(result, Err(RemoteProcessError::BadResponse)),
        "Unexpected result: {:?}",
        result
    );
}
#[test]
fn id_should_return_randomly_generated_process_id() {
todo!();
#[tokio::test]
async fn spawn_should_return_bad_response_if_did_not_get_a_indicator_that_process_started() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::Error(Error {
kind: ErrorKind::Other,
description: String::from("some error"),
})],
))
.await
.unwrap();
// Get the spawn result and verify
let result = spawn_task.await.unwrap();
assert!(
matches!(result, Err(RemoteProcessError::BadResponse)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn wait_should_wait_for_internal_tasks_to_complete_and_return_process_exit_information() {
todo!();
}
#[tokio::test]
async fn kill_should_return_error_if_internal_tasks_already_completed() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
#[test]
fn wait_should_return_error_if_internal_tasks_fail() {
todo!();
}
// Receive the process and then abort it to make kill fail
let mut proc = spawn_task.await.unwrap().unwrap();
proc.abort();
// Ensure that the other tasks are aborted before continuing
tokio::task::yield_now().await;
#[test]
fn abort_should_abort_internal_tasks() {
todo!();
let result = proc.kill().await;
assert!(
matches!(result, Err(RemoteProcessError::ChannelDead)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn kill_should_return_error_if_internal_tasks_already_completed() {
todo!();
#[tokio::test]
async fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then kill it
let mut proc = spawn_task.await.unwrap().unwrap();
assert!(proc.kill().await.is_ok(), "Failed to send kill request");
// Verify the kill request was sent
let req = transport.receive::<Request>().await.unwrap().unwrap();
assert_eq!(
req.payload.len(),
1,
"Unexpected payload length for kill request"
);
assert_eq!(req.payload[0], RequestData::ProcKill { id });
// Verify we can no longer write to stdin anymore
assert_eq!(
proc.stdin
.as_mut()
.unwrap()
.write("some stdin")
.await
.unwrap_err()
.kind(),
io::ErrorKind::BrokenPipe
);
}
#[test]
fn kill_should_send_proc_kill_request_and_then_cause_stdin_forwarding_to_close() {
todo!();
#[tokio::test]
async fn stdin_should_be_forwarded_from_receiver_field() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then send stdin
let mut proc = spawn_task.await.unwrap().unwrap();
proc.stdin
.as_mut()
.unwrap()
.write("some input")
.await
.unwrap();
// Verify that a request is made through the session
match &transport
.receive::<Request>()
.await
.unwrap()
.unwrap()
.payload[0]
{
RequestData::ProcStdin { id, data } => {
assert_eq!(*id, 12345);
assert_eq!(data, "some input");
}
x => panic!("Unexpected request: {:?}", x),
}
}
#[test]
fn stdin_should_be_forwarded_from_receiver_field() {
todo!();
#[tokio::test]
async fn stdout_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then read stdout
let mut proc = spawn_task.await.unwrap().unwrap();
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStdout {
id,
data: String::from("some out"),
}],
))
.await
.unwrap();
let out = proc.stdout.as_mut().unwrap().read().await.unwrap();
assert_eq!(out, "some out");
}
#[test]
fn stdout_should_be_forwarded_to_receiver_field() {
todo!();
#[tokio::test]
async fn stderr_should_be_forwarded_to_receiver_field() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then read stderr
let mut proc = spawn_task.await.unwrap().unwrap();
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcStderr {
id,
data: String::from("some err"),
}],
))
.await
.unwrap();
let out = proc.stderr.as_mut().unwrap().read().await.unwrap();
assert_eq!(out, "some err");
}
#[test]
fn stderr_should_be_forwarded_to_receiver_field() {
todo!();
#[tokio::test]
async fn wait_should_return_error_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then abort it to make internal tasks fail
let proc = spawn_task.await.unwrap().unwrap();
proc.abort();
let result = proc.wait().await;
assert!(
matches!(result, Err(RemoteProcessError::WaitFailed(_))),
"Unexpected result: {:?}",
result
);
}
#[test]
fn receiving_done_response_should_terminate_internal_tasks() {
todo!();
#[tokio::test]
async fn wait_should_return_error_if_connection_terminates_before_receiving_done_response() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then terminate session connection
let proc = spawn_task.await.unwrap().unwrap();
drop(transport);
// Ensure that the other tasks are cancelled before continuing
tokio::task::yield_now().await;
let result = proc.wait().await;
assert!(
matches!(result, Err(RemoteProcessError::UnexpectedEof)),
"Unexpected result: {:?}",
result
);
}
#[test]
fn receiving_done_response_should_result_in_wait_returning_exit_information() {
todo!();
#[tokio::test]
async fn receiving_done_response_should_result_in_wait_returning_exit_information() {
let (mut transport, session) = make_session();
// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(RemoteProcess::spawn(
String::from("test-tenant"),
session,
String::from("cmd"),
vec![String::from("arg")],
));
// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();
// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
Some(req.id),
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();
// Receive the process and then spawn a task for it to complete
let proc = spawn_task.await.unwrap().unwrap();
let proc_wait_task = tokio::spawn(proc.wait());
// Send a process completion response to pass along exit status and conclude wait
transport
.send(Response::new(
"test-tenant",
None,
vec![ResponseData::ProcDone {
id,
success: false,
code: Some(123),
}],
))
.await
.unwrap();
// Finally, verify that we complete and get the expected results
assert_eq!(proc_wait_task.await.unwrap().unwrap(), (false, Some(123)));
}
}

@ -1,5 +1,5 @@
use crate::net::{SecretKey, UnprotectedToHexKey};
use derive_more::{Display, Error};
use orion::aead::SecretKey;
use std::{
env,
net::{IpAddr, SocketAddr},
@ -141,11 +141,6 @@ impl SessionInfo {
Ok(SocketAddr::from((addr, self.port)))
}
/// Returns a string representing the auth key as hex
pub fn to_unprotected_hex_auth_key(&self) -> String {
hex::encode(self.auth_key.unprotected_as_bytes())
}
/// Converts to unprotected string that exposes the auth key in the form of
/// `DISTANT DATA <host> <port> <auth key>`
pub fn to_unprotected_string(&self) -> String {
@ -153,7 +148,7 @@ impl SessionInfo {
"DISTANT DATA {} {} {}",
self.host,
self.port,
self.to_unprotected_hex_auth_key()
self.auth_key.unprotected_to_hex_key()
)
}
}

@ -2,7 +2,7 @@ use crate::{
client::utils,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
net::{DataStream, InmemoryStream, SecretKey, Transport, TransportError, TransportWriteHalf},
net::{DataStream, SecretKey, Transport, TransportError, TransportWriteHalf},
};
use log::*;
use std::{
@ -41,13 +41,6 @@ where
pub broadcast: Option<mpsc::Receiver<Response>>,
}
impl Session<InmemoryStream> {
/// Creates a session around an inmemory transport
pub async fn from_inmemory_transport(transport: Transport<InmemoryStream>) -> io::Result<Self> {
Self::initialize(transport).await
}
}
impl Session<TcpStream> {
/// Connect to a remote TCP server using the provided information
pub async fn tcp_connect(info: SessionInfo) -> io::Result<Self> {
@ -61,7 +54,7 @@ impl Session<TcpStream> {
.map(|x| x.to_string())
.unwrap_or_else(|_| String::from("???"))
);
Self::initialize(transport).await
Self::initialize(transport)
}
/// Connect to a remote TCP server, timing out after duration has passed
@ -87,7 +80,7 @@ impl Session<tokio::net::UnixStream> {
.map(|x| format!("{:?}", x))
.unwrap_or_else(|_| String::from("???"))
);
Self::initialize(transport).await
Self::initialize(transport)
}
/// Connect to a proxy unix socket, timing out after duration has passed
@ -107,7 +100,7 @@ where
T: DataStream,
{
/// Initializes a session using the provided transport
pub async fn initialize(transport: Transport<T>) -> io::Result<Self> {
pub fn initialize(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
@ -231,7 +224,7 @@ mod tests {
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
@ -252,7 +245,7 @@ mod tests {
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.send_timeout(req, Duration::from_millis(30)).await {
@ -267,7 +260,7 @@ mod tests {
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let mut session = Session::initialize(t1).unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.fire(req).await {

@ -15,6 +15,10 @@ pub const READ_PAUSE_MILLIS: u64 = 50;
/// Represents the length of the salt to use for encryption
pub const SALT_LEN: usize = 16;
/// Represents time in milliseconds a connection has to perform a handshake (on server side)
/// before the server discards the connection (5 minutes)
pub const CONN_HANDSHAKE_TIMEOUT_MILLIS: u64 = 300000;
/// Test-only constants
#[cfg(test)]
pub mod test {

@ -207,6 +207,10 @@ pub enum RequestData {
/// intermediate components normalized and symbolic links resolved
#[cfg_attr(feature = "structopt", structopt(short, long))]
canonicalize: bool,
/// Whether or not to follow symlinks to determine absolute file type (dir/file)
#[cfg_attr(feature = "structopt", structopt(long))]
resolve_file_type: bool,
},
/// Runs a process on the remote machine

@ -11,11 +11,12 @@ mod constants;
mod net;
pub use net::{
DataStream, InmemoryStream, InmemoryStreamReadHalf, InmemoryStreamWriteHalf, Listener,
SecretKey, Transport, TransportError, TransportReadHalf, TransportWriteHalf,
SecretKey, Transport, TransportError, TransportListener, TransportListenerCtx,
TransportReadHalf, TransportWriteHalf, UnprotectedToHexKey,
};
pub mod data;
pub use data::{Request, RequestData, Response, ResponseData};
mod server;
pub use server::{DistantServer, PortRange, RelayServer};
pub use server::{DistantServer, DistantServerOptions, PortRange, RelayServer};

@ -1,29 +1,151 @@
use super::DataStream;
use std::{future::Future, pin::Pin};
use super::{DataStream, SecretKey, Transport};
use futures::stream::Stream;
use log::*;
use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
use tokio::{
io,
net::{TcpListener, TcpStream},
sync::mpsc,
task::JoinHandle,
};
/// Represents a type that has a listen interface
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TransportListenerCtx {
pub auth_key: Option<Arc<SecretKey>>,
pub timeout: Option<Duration>,
}
/// Represents a [`Stream`] consisting of newly-connected [`DataStream`] instances that
/// have been wrapped in [`Transport`]
pub struct TransportListener<T>
where
T: DataStream,
{
listen_task: JoinHandle<()>,
accept_task: JoinHandle<()>,
rx: mpsc::Receiver<Transport<T>>,
}
impl<T> TransportListener<T>
where
T: DataStream + Send + 'static,
{
pub fn initialize<L>(listener: L, ctx: TransportListenerCtx) -> Self
where
L: Listener<Output = T> + 'static,
{
let (stream_tx, mut stream_rx) = mpsc::channel::<T>(1);
let listen_task = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok(stream) => {
if stream_tx.send(stream).await.is_err() {
error!("Listener failed to pass along stream");
break;
}
}
Err(x) => {
error!("Listener failed to accept stream: {}", x);
break;
}
}
}
});
let TransportListenerCtx { auth_key, timeout } = ctx;
let (tx, rx) = mpsc::channel::<Transport<T>>(1);
let accept_task = tokio::spawn(async move {
// Check if we have a new connection. If so, spawn a task for it
while let Some(stream) = stream_rx.recv().await {
let auth_key = auth_key.as_ref().cloned();
let tx_2 = tx.clone();
tokio::spawn(async move {
match do_handshake(stream, auth_key, timeout).await {
Ok(transport) => {
if let Err(x) = tx_2.send(transport).await {
error!("Failed to forward transport: {}", x);
panic!("{}", x);
}
}
Err(x) => {
error!("Transport handshake failed: {}", x);
panic!("{}", x);
}
}
});
}
});
Self {
listen_task,
accept_task,
rx,
}
}
pub fn abort(&self) {
self.listen_task.abort();
self.accept_task.abort();
}
/// Waits for the next fully-initialized transport for an incoming stream to be available,
/// returning none if no longer accepting new connections
pub async fn accept(&mut self) -> Option<Transport<T>> {
self.rx.recv().await
}
/// Converts into a stream of transport-wrapped connections
pub fn into_stream(self) -> impl Stream<Item = Transport<T>> {
futures::stream::unfold(self, |mut _self| async move {
_self
.accept()
.await
.map(move |transport| (transport, _self))
})
}
}
async fn do_handshake<T>(
stream: T,
auth_key: Option<Arc<SecretKey>>,
timeout: Option<Duration>,
) -> io::Result<Transport<T>>
where
T: DataStream,
{
if let Some(timeout) = timeout {
tokio::select! {
res = Transport::from_handshake(stream, auth_key) => {
res
}
_ = tokio::time::sleep(timeout) => {
Err(io::Error::from(io::ErrorKind::TimedOut))
}
}
} else {
Transport::from_handshake(stream, auth_key).await
}
}
pub type AcceptFuture<'a, T> = Pin<Box<dyn Future<Output = io::Result<T>> + Send + 'a>>;
/// Represents a type that has a listen interface for receiving raw streams
pub trait Listener: Send + Sync {
type Conn: DataStream;
type Output;
/// Async function that accepts a new connection, returning `Ok(Self::Conn)`
/// upon receiving the next connection
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
where
Self: Sync + 'a;
}
impl Listener for TcpListener {
type Conn = TcpStream;
type Output = TcpStream;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
where
Self: Sync + 'a,
{
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
async fn accept<'a>(_self: &'a TcpListener) -> io::Result<TcpStream> {
_self.accept().await.map(|(stream, _)| stream)
}
@ -33,9 +155,9 @@ impl Listener for TcpListener {
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Conn = tokio::net::UnixStream;
type Output = tokio::net::UnixStream;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
where
Self: Sync + 'a,
{
@ -48,16 +170,22 @@ impl Listener for tokio::net::UnixListener {
}
#[cfg(test)]
impl<T: DataStream + Send + Sync> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>> {
type Conn = T;
impl<T> Listener for tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>
where
T: DataStream + Send + Sync + 'static,
{
type Output = T;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
fn accept<'a>(&'a self) -> AcceptFuture<'a, Self::Output>
where
Self: Sync + 'a,
{
async fn accept<T>(
_self: &tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>,
) -> io::Result<T> {
async fn accept<'a, T>(
_self: &'a tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>,
) -> io::Result<T>
where
T: DataStream + Send + Sync + 'static,
{
_self
.lock()
.await

@ -1,8 +1,18 @@
mod listener;
mod transport;
pub use listener::Listener;
pub use listener::{AcceptFuture, Listener, TransportListener, TransportListenerCtx};
pub use transport::*;
// Re-export commonly-used orion structs
pub use orion::aead::SecretKey;
pub trait UnprotectedToHexKey {
fn unprotected_to_hex_key(&self) -> String;
}
impl UnprotectedToHexKey for SecretKey {
fn unprotected_to_hex_key(&self) -> String {
hex::encode(self.unprotected_as_bytes())
}
}

@ -10,6 +10,7 @@ use tokio::{
};
/// Represents a data stream comprised of two inmemory channels
#[derive(Debug)]
pub struct InmemoryStream {
incoming: InmemoryStreamReadHalf,
outgoing: InmemoryStreamWriteHalf,
@ -72,6 +73,7 @@ impl AsyncWrite for InmemoryStream {
}
/// Read portion of an inmemory channel
#[derive(Debug)]
pub struct InmemoryStreamReadHalf(mpsc::Receiver<Vec<u8>>);
impl AsyncRead for InmemoryStreamReadHalf {
@ -91,6 +93,7 @@ impl AsyncRead for InmemoryStreamReadHalf {
}
/// Write portion of an inmemory channel
#[derive(Debug)]
pub struct InmemoryStreamWriteHalf(mpsc::Sender<Vec<u8>>);
impl AsyncWrite for InmemoryStreamWriteHalf {

@ -140,6 +140,7 @@ macro_rules! recv {
}
/// Represents a transport of data across the network
#[derive(Debug)]
pub struct Transport<T>
where
T: DataStream,
@ -264,6 +265,11 @@ where
recv!(self.conn, self.crypt_key, self.auth_key).await
}
/// Returns a textual description of the transport's underlying connection
pub fn to_connection_tag(&self) -> String {
self.conn.get_ref().to_connection_tag()
}
/// Splits transport into read and write halves
pub fn into_split(self) -> (TransportReadHalf<T::Read>, TransportWriteHalf<T::Write>) {
let crypt_key = self.crypt_key;

@ -73,7 +73,11 @@ pub(super) async fn process(
RequestData::Copy { src, dst } => copy(src, dst).await,
RequestData::Rename { src, dst } => rename(src, dst).await,
RequestData::Exists { path } => exists(path).await,
RequestData::Metadata { path, canonicalize } => metadata(path, canonicalize).await,
RequestData::Metadata {
path,
canonicalize,
resolve_file_type,
} => metadata(path, canonicalize, resolve_file_type).await,
RequestData::ProcRun { cmd, args } => {
proc_run(tenant.to_string(), conn_id, state, tx, cmd, args).await
}
@ -324,14 +328,26 @@ async fn exists(path: PathBuf) -> Result<ResponseData, ServerError> {
})
}
async fn metadata(path: PathBuf, canonicalize: bool) -> Result<ResponseData, ServerError> {
async fn metadata(
path: PathBuf,
canonicalize: bool,
resolve_file_type: bool,
) -> Result<ResponseData, ServerError> {
let metadata = tokio::fs::symlink_metadata(path.as_path()).await?;
let canonicalized_path = if canonicalize {
Some(tokio::fs::canonicalize(path).await?)
Some(tokio::fs::canonicalize(path.as_path()).await?)
} else {
None
};
// If asking for resolved file type and current type is symlink, then we want to refresh
// our metadata to get the filetype for the resolved link
let file_type = if resolve_file_type && metadata.file_type().is_symlink() {
tokio::fs::metadata(path).await?.file_type()
} else {
metadata.file_type()
};
Ok(ResponseData::Metadata {
canonicalized_path,
accessed: metadata
@ -351,9 +367,9 @@ async fn metadata(path: PathBuf, canonicalize: bool) -> Result<ResponseData, Ser
.map(|d| d.as_millis()),
len: metadata.len(),
readonly: metadata.permissions().readonly(),
file_type: if metadata.is_dir() {
file_type: if file_type.is_dir() {
FileType::Dir
} else if metadata.is_file() {
} else if file_type.is_file() {
FileType::File
} else {
FileType::Symlink
@ -1801,6 +1817,7 @@ mod tests {
vec![RequestData::Metadata {
path: file.path().to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
);
@ -1827,6 +1844,7 @@ mod tests {
vec![RequestData::Metadata {
path: file.path().to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
);
@ -1862,6 +1880,7 @@ mod tests {
vec![RequestData::Metadata {
path: dir.path().to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
);
@ -1899,6 +1918,7 @@ mod tests {
vec![RequestData::Metadata {
path: symlink.path().to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
);
@ -1936,6 +1956,7 @@ mod tests {
vec![RequestData::Metadata {
path: symlink.path().to_path_buf(),
canonicalize: true,
resolve_file_type: false,
}],
);
@ -1958,6 +1979,38 @@ mod tests {
}
}
#[tokio::test]
async fn metadata_should_resolve_file_type_of_symlink_if_flag_specified() {
let (conn_id, state, tx, mut rx) = setup(1);
let temp = assert_fs::TempDir::new().unwrap();
let file = temp.child("file");
file.write_str("some text").unwrap();
let symlink = temp.child("link");
symlink.symlink_to_file(file.path()).unwrap();
let req = Request::new(
"test-tenant",
vec![RequestData::Metadata {
path: symlink.path().to_path_buf(),
canonicalize: false,
resolve_file_type: true,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::Metadata {
file_type: FileType::File,
..
} => {}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn proc_run_should_send_error_on_failure() {
let (conn_id, state, tx, mut rx) = setup(1);

@ -4,13 +4,18 @@ mod state;
use state::State;
use crate::{
constants::CONN_HANDSHAKE_TIMEOUT_MILLIS,
data::{Request, Response},
net::{DataStream, Listener, SecretKey, Transport, TransportReadHalf, TransportWriteHalf},
net::{
DataStream, SecretKey, Transport, TransportListener, TransportListenerCtx,
TransportReadHalf, TransportWriteHalf,
},
server::{
utils::{ConnTracker, ShutdownTask},
PortRange,
},
};
use futures::stream::{Stream, StreamExt};
use log::*;
use std::{net::IpAddr, sync::Arc};
use tokio::{
@ -23,79 +28,67 @@ use tokio::{
/// Represents a server that listens for requests, processes them, and sends responses
pub struct DistantServer {
port: u16,
auth_key: Arc<SecretKey>,
conn_task: JoinHandle<()>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DistantServerOptions {
pub shutdown_after: Option<Duration>,
pub max_msg_capacity: usize,
}
impl Default for DistantServerOptions {
fn default() -> Self {
Self {
shutdown_after: None,
max_msg_capacity: 1,
}
}
}
impl DistantServer {
/// Bind to an IP address and port from the given range, taking an optional shutdown duration
/// that will shutdown the server if there is no active connection after duration
pub async fn bind(
addr: IpAddr,
port: PortRange,
shutdown_after: Option<Duration>,
max_msg_capacity: usize,
) -> io::Result<Self> {
auth_key: Option<Arc<SecretKey>>,
opts: DistantServerOptions,
) -> io::Result<(Self, u16)> {
debug!("Binding to {} in range {}", addr, port);
let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?;
let port = listener.local_addr()?.port();
debug!("Bound to port: {}", port);
Ok(Self::initialize(
let stream = TransportListener::initialize(
listener,
port,
shutdown_after,
max_msg_capacity,
))
TransportListenerCtx {
auth_key,
timeout: Some(Duration::from_millis(CONN_HANDSHAKE_TIMEOUT_MILLIS)),
},
)
.into_stream();
Ok((Self::initialize(Box::pin(stream), opts), port))
}
/// Initialize a distant server using the provided listener
pub fn initialize<T, L>(
listener: L,
port: u16,
shutdown_after: Option<Duration>,
max_msg_capacity: usize,
) -> Self
pub fn initialize<T, S>(stream: S, opts: DistantServerOptions) -> Self
where
T: DataStream + Send + 'static,
L: Listener<Conn = T> + 'static,
S: Stream<Item = Transport<T>> + Send + Unpin + 'static,
{
// Build our state for the server
let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::default()));
let auth_key = Arc::new(SecretKey::default());
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
let (shutdown, tracker) = ShutdownTask::maybe_initialize(opts.shutdown_after);
// Spawn our connection task
let auth_key_2 = Arc::clone(&auth_key);
let conn_task = tokio::spawn(async move {
connection_loop(
listener,
state,
auth_key_2,
tracker,
shutdown,
max_msg_capacity,
)
.await
connection_loop(stream, state, tracker, shutdown, opts.max_msg_capacity).await
});
Self {
port,
auth_key,
conn_task,
}
}
/// Returns the port this server is bound to
pub fn port(&self) -> u16 {
self.port
}
/// Returns a string representing the auth key as hex
pub fn to_unprotected_hex_auth_key(&self) -> String {
hex::encode(self.auth_key.unprotected_as_bytes())
Self { conn_task }
}
/// Waits for the server to terminate
@ -109,32 +102,30 @@ impl DistantServer {
}
}
async fn connection_loop<T, L>(
listener: L,
async fn connection_loop<T, S>(
mut stream: S,
state: Arc<Mutex<State>>,
auth_key: Arc<SecretKey>,
tracker: Option<Arc<Mutex<ConnTracker>>>,
shutdown: Option<ShutdownTask>,
max_msg_capacity: usize,
) where
T: DataStream,
L: Listener<Conn = T>,
T: DataStream + Send + 'static,
S: Stream<Item = Transport<T>> + Send + Unpin + 'static,
{
let inner = async move {
loop {
match listener.accept().await {
Ok(conn) => {
match stream.next().await {
Some(transport) => {
let conn_id = rand::random();
debug!(
"<Conn @ {}> Established against {}",
conn_id,
conn.to_connection_tag()
transport.to_connection_tag()
);
if let Err(x) = on_new_conn(
conn,
transport,
conn_id,
Arc::clone(&state),
Arc::clone(&auth_key),
tracker.as_ref().map(Arc::clone),
max_msg_capacity,
)
@ -143,11 +134,11 @@ async fn connection_loop<T, L>(
error!("<Conn @ {}> Failed handshake: {}", conn_id, x);
}
}
Err(x) => {
error!("Listener failed: {}", x);
None => {
info!("Listener shutting down");
break;
}
}
};
}
};
@ -165,10 +156,9 @@ async fn connection_loop<T, L>(
/// Processes a new connection, performing a handshake, and then spawning two tasks to handle
/// input and output, returning join handles for the input and output tasks respectively
async fn on_new_conn<T>(
conn: T,
transport: Transport<T>,
conn_id: usize,
state: Arc<Mutex<State>>,
auth_key: Arc<SecretKey>,
tracker: Option<Arc<Mutex<ConnTracker>>>,
max_msg_capacity: usize,
) -> io::Result<JoinHandle<()>>
@ -180,10 +170,6 @@ where
ct.lock().await.increment();
}
// Establish a proper connection via a handshake,
// discarding the connection otherwise
let transport = Transport::from_handshake(conn, Some(auth_key)).await?;
// Split the transport into read and write halves so we can handle input
// and output concurrently
let (t_read, t_write) = transport.into_split();
@ -283,34 +269,90 @@ async fn response_loop<T>(
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn wait_should_return_ok_when_all_inner_tasks_complete() {
todo!();
use crate::{
data::{RequestData, ResponseData},
net::InmemoryStream,
};
use std::pin::Pin;
fn make_transport_stream() -> (
mpsc::Sender<Transport<InmemoryStream>>,
Pin<Box<dyn Stream<Item = Transport<InmemoryStream>> + Send>>,
) {
let (tx, rx) = mpsc::channel::<Transport<InmemoryStream>>(1);
let stream = futures::stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(move |transport| (transport, rx))
});
(tx, Box::pin(stream))
}
#[test]
fn wait_should_return_error_when_server_aborted() {
todo!();
}
#[tokio::test]
async fn wait_should_return_ok_when_all_inner_tasks_complete() {
let (tx, stream) = make_transport_stream();
let server = DistantServer::initialize(stream, Default::default());
#[test]
fn abort_should_abort_inner_tasks_and_all_connections() {
todo!();
// Conclude all server tasks by closing out the listener
drop(tx);
let result = server.wait().await;
assert!(result.is_ok(), "Unexpected result: {:?}", result);
}
#[test]
fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
todo!();
#[tokio::test]
async fn wait_should_return_error_when_server_aborted() {
let (_tx, stream) = make_transport_stream();
let server = DistantServer::initialize(stream, Default::default());
server.abort();
match server.wait().await {
Err(x) if x.is_cancelled() => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn server_shutdown_should_abort_all_connections() {
todo!();
#[tokio::test]
async fn server_should_receive_requests_and_send_responses_to_appropriate_connections() {
let (tx, stream) = make_transport_stream();
let _server = DistantServer::initialize(stream, Default::default());
// Send over a "connection"
let (mut t1, t2) = Transport::make_pair();
tx.send(t2).await.unwrap();
// Send a request
t1.send(Request::new(
"test-tenant",
vec![RequestData::SystemInfo {}],
))
.await
.unwrap();
// Get a response
let res = t1.receive::<Response>().await.unwrap().unwrap();
assert!(res.payload.len() == 1, "Unexpected payload size");
assert!(
matches!(res.payload[0], ResponseData::SystemInfo { .. }),
"Unexpected response: {:?}",
res.payload[0]
);
}
#[test]
fn server_should_execute_requests_and_return_responses() {
todo!();
#[tokio::test]
async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
let (_tx, stream) = make_transport_stream();
let server = DistantServer::initialize(
stream,
DistantServerOptions {
shutdown_after: Some(Duration::from_millis(50)),
max_msg_capacity: 1,
},
);
let result = server.wait().await;
assert!(result.is_ok(), "Unexpected result: {:?}", result);
}
}

@ -3,6 +3,6 @@ mod port;
mod relay;
mod utils;
pub use self::distant::DistantServer;
pub use self::distant::{DistantServer, DistantServerOptions};
pub use port::PortRange;
pub use relay::RelayServer;

@ -2,9 +2,10 @@ use crate::{
client::Session,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{DataStream, Listener, Transport, TransportReadHalf, TransportWriteHalf},
net::{DataStream, Transport, TransportReadHalf, TransportWriteHalf},
server::utils::{ConnTracker, ShutdownTask},
};
use futures::stream::{Stream, StreamExt};
use log::*;
use std::{collections::HashMap, marker::Unpin, sync::Arc};
use tokio::{
@ -24,15 +25,15 @@ pub struct RelayServer {
}
impl RelayServer {
pub fn initialize<T1, T2, L>(
pub fn initialize<T1, T2, S>(
mut session: Session<T1>,
listener: L,
mut stream: S,
shutdown_after: Option<Duration>,
) -> io::Result<Self>
where
T1: DataStream + 'static,
T2: DataStream + Send + 'static,
L: Listener<Conn = T2> + 'static,
S: Stream<Item = Transport<T2>> + Send + Unpin + 'static,
{
let conns: Arc<Mutex<HashMap<usize, Conn>>> = Arc::new(Mutex::new(HashMap::new()));
@ -40,8 +41,21 @@ impl RelayServer {
let conns_2 = Arc::clone(&conns);
debug!("Spawning response broadcast task");
let mut broadcast = session.broadcast.take().unwrap();
let (shutdown_broadcast_tx, mut shutdown_broadcast_rx) = mpsc::channel::<()>(1);
let broadcast_task = tokio::spawn(async move {
while let Some(res) = broadcast.recv().await {
loop {
let res = tokio::select! {
maybe_res = broadcast.recv() => {
match maybe_res {
Some(res) => res,
None => break,
}
}
_ = shutdown_broadcast_rx.recv() => {
break;
}
};
// Search for all connections with a tenant that matches the response's tenant
for conn in conns_2.lock().await.values_mut() {
if conn.state.lock().await.tenant.as_deref() == Some(res.tenant.as_str()) {
@ -66,8 +80,21 @@ impl RelayServer {
// Spawn task to send to the server requests from connections
debug!("Spawning request forwarding task");
let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (shutdown_forward_tx, mut shutdown_forward_rx) = mpsc::channel::<()>(1);
let forward_task = tokio::spawn(async move {
while let Some(req) = req_rx.recv().await {
loop {
let req = tokio::select! {
maybe_req = req_rx.recv() => {
match maybe_req {
Some(req) => req,
None => break,
}
}
_ = shutdown_forward_rx.recv() => {
break;
}
};
debug!(
"Forwarding request of type{} {} to server",
if req.payload.len() > 1 { "s" } else { "" },
@ -85,28 +112,29 @@ impl RelayServer {
let accept_task = tokio::spawn(async move {
let inner = async move {
loop {
match listener.accept().await {
Ok(stream) => {
match stream.next().await {
Some(transport) => {
let result = Conn::initialize(
stream,
transport,
req_tx.clone(),
tracker.as_ref().map(Arc::clone),
)
.await;
match result {
Ok(conn) => conns_2.lock().await.insert(conn.id(), conn),
Ok(conn) => {
conns_2.lock().await.insert(conn.id(), conn);
}
Err(x) => {
error!("Failed to initialize connection: {}", x);
continue;
}
};
}
Err(x) => {
debug!("Listener has closed: {}", x);
None => {
info!("Listener shutting down");
break;
}
}
};
}
};
@ -119,6 +147,11 @@ impl RelayServer {
},
None => inner.await,
}
// Doesn't matter if we send or drop these as long as they persist until this
// task is completed, so just drop
drop(shutdown_broadcast_tx);
drop(shutdown_forward_tx);
});
Ok(Self {
@ -154,7 +187,7 @@ struct Conn {
id: usize,
req_task: JoinHandle<()>,
res_task: JoinHandle<()>,
cleanup_task: JoinHandle<()>,
_cleanup_task: JoinHandle<()>,
res_tx: mpsc::Sender<Response>,
state: Arc<Mutex<ConnState>>,
}
@ -168,7 +201,7 @@ struct ConnState {
impl Conn {
pub async fn initialize<T>(
stream: T,
transport: Transport<T>,
req_tx: mpsc::Sender<Request>,
ct: Option<Arc<Mutex<ConnTracker>>>,
) -> io::Result<Self>
@ -179,11 +212,6 @@ impl Conn {
// is not guaranteed to have an identifiable string
let id: usize = rand::random();
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = Transport::from_handshake(stream, None).await.map_err(|x| {
error!("<Conn @ {}> Failed handshake: {}", id, x);
io::Error::new(io::ErrorKind::Other, x)
})?;
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
@ -220,7 +248,7 @@ impl Conn {
let _ = req_task_tx.send(());
});
let cleanup_task = tokio::spawn(async move {
let _cleanup_task = tokio::spawn(async move {
let _ = tokio::join!(req_task_rx, res_task_rx);
if let Some(ct) = ct.as_ref() {
@ -233,7 +261,7 @@ impl Conn {
id,
req_task,
res_task,
cleanup_task,
_cleanup_task,
res_tx,
state,
})
@ -384,49 +412,130 @@ async fn handle_conn_outgoing<T>(
#[cfg(test)]
mod tests {
use super::*;
use crate::net::InmemoryStream;
use std::{pin::Pin, time::Duration};
#[test]
fn wait_should_return_ok_when_all_inner_tasks_complete() {
todo!();
fn make_session() -> (Transport<InmemoryStream>, Session<InmemoryStream>) {
let (t1, t2) = Transport::make_pair();
(t1, Session::initialize(t2).unwrap())
}
#[test]
fn wait_should_return_error_when_server_aborted() {
todo!();
fn make_transport_stream() -> (
mpsc::Sender<Transport<InmemoryStream>>,
Pin<Box<dyn Stream<Item = Transport<InmemoryStream>> + Send>>,
) {
let (tx, rx) = mpsc::channel::<Transport<InmemoryStream>>(1);
let stream = futures::stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(move |transport| (transport, rx))
});
(tx, Box::pin(stream))
}
#[test]
fn abort_should_abort_inner_tasks_and_all_connections() {
todo!();
}
#[tokio::test]
async fn wait_should_return_ok_when_all_inner_tasks_complete() {
let (transport, session) = make_session();
let (tx, stream) = make_transport_stream();
let server = RelayServer::initialize(session, stream, None).unwrap();
#[test]
fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
todo!();
}
// Conclude all server tasks by closing out the listener & session
drop(transport);
drop(tx);
#[test]
fn server_shutdown_should_abort_all_connections() {
todo!();
let result = server.wait().await;
assert!(result.is_ok(), "Unexpected result: {:?}", result);
}
#[test]
fn server_should_forward_connection_requests_to_session() {
todo!();
#[tokio::test]
async fn wait_should_return_error_when_server_aborted() {
let (_transport, session) = make_session();
let (_tx, stream) = make_transport_stream();
let server = RelayServer::initialize(session, stream, None).unwrap();
server.abort().await;
match server.wait().await {
Err(x) if x.is_cancelled() => {}
x => panic!("Unexpected result: {:?}", x),
}
}
#[test]
fn server_should_forward_session_responses_to_connection_with_matching_tenant() {
todo!();
#[tokio::test]
async fn server_should_forward_requests_using_session() {
let (mut transport, session) = make_session();
let (tx, stream) = make_transport_stream();
let _server = RelayServer::initialize(session, stream, None).unwrap();
// Send over a "connection"
let (mut t1, t2) = Transport::make_pair();
tx.send(t2).await.unwrap();
// Send a request
let req = Request::new("test-tenant", vec![RequestData::SystemInfo {}]);
t1.send(req.clone()).await.unwrap();
// Verify the request is forwarded out via session
let outbound_req = transport.receive().await.unwrap().unwrap();
assert_eq!(req, outbound_req);
}
#[test]
fn connection_abort_should_abort_inner_tasks() {
todo!();
#[tokio::test]
async fn server_should_send_back_response_with_tenant_matching_connection() {
let (mut transport, session) = make_session();
let (tx, stream) = make_transport_stream();
let _server = RelayServer::initialize(session, stream, None).unwrap();
// Send over a "connection"
let (mut t1, t2) = Transport::make_pair();
tx.send(t2).await.unwrap();
// Send over a second "connection"
let (mut t2, t3) = Transport::make_pair();
tx.send(t3).await.unwrap();
// Send a request to mark the tenant of the first connection
t1.send(Request::new(
"test-tenant-1",
vec![RequestData::SystemInfo {}],
))
.await
.unwrap();
// Send a request to mark the tenant of the second connection
t2.send(Request::new(
"test-tenant-2",
vec![RequestData::SystemInfo {}],
))
.await
.unwrap();
// Clear out the transport channel (outbound of session)
// NOTE: Because our test stream uses a buffer size of 1, we have to clear out the
// outbound data from the earlier requests before we can send back a response
let _ = transport.receive::<Request>().await.unwrap().unwrap();
let _ = transport.receive::<Request>().await.unwrap().unwrap();
// Send a response back to a singular connection based on the tenant
let res = Response::new("test-tenant-2", None, vec![ResponseData::Ok]);
transport.send(res.clone()).await.unwrap();
// Verify that response is only received by a singular connection
let inbound_res = t2.receive().await.unwrap().unwrap();
assert_eq!(res, inbound_res);
let no_inbound = tokio::select! {
_ = t1.receive::<Response>() => {false}
_ = tokio::time::sleep(Duration::from_millis(50)) => {true}
};
assert!(no_inbound, "Unexpectedly got response for wrong connection");
}
#[test]
fn connection_abort_should_send_process_kill_requests_through_session() {
todo!();
#[tokio::test]
async fn server_should_shutdown_if_no_connections_after_shutdown_duration() {
let (_transport, session) = make_session();
let (_tx, stream) = make_transport_stream();
let server =
RelayServer::initialize(session, stream, Some(Duration::from_millis(50))).unwrap();
let result = server.wait().await;
assert!(result.is_ok(), "Unexpected result: {:?}", result);
}
}

@ -5,7 +5,9 @@ use crate::{
utils,
};
use derive_more::{Display, Error, From};
use distant_core::{RelayServer, Session, SessionInfo, SessionInfoFile};
use distant_core::{
RelayServer, Session, SessionInfo, SessionInfoFile, TransportListener, TransportListenerCtx,
};
use fork::{daemon, Fork};
use log::*;
use std::{path::Path, string::FromUtf8Error};
@ -154,7 +156,16 @@ async fn socket_loop(
debug!("Binding to unix socket: {:?}", socket_path.as_ref());
let listener = tokio::net::UnixListener::bind(socket_path)?;
let server = RelayServer::initialize(session, listener, shutdown_after)?;
let stream = TransportListener::initialize(
listener,
TransportListenerCtx {
auth_key: None,
timeout: Some(duration),
},
)
.into_stream();
let server = RelayServer::initialize(session, Box::pin(stream), shutdown_after)?;
server
.wait()
.await

@ -3,9 +3,10 @@ use crate::{
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
};
use derive_more::{Display, Error, From};
use distant_core::DistantServer;
use distant_core::{DistantServer, DistantServerOptions, SecretKey, UnprotectedToHexKey};
use fork::{daemon, Fork};
use log::*;
use std::sync::Arc;
use tokio::{io, task::JoinError};
#[derive(Debug, Display, Error, From)]
@ -62,19 +63,23 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
}
// Bind & start our server
let server = DistantServer::bind(
let auth_key = Arc::new(SecretKey::default());
let (server, port) = DistantServer::bind(
addr,
cmd.port,
shutdown_after,
cmd.max_msg_capacity as usize,
Some(Arc::clone(&auth_key)),
DistantServerOptions {
shutdown_after,
max_msg_capacity: cmd.max_msg_capacity as usize,
},
)
.await?;
// Print information about port, key, etc.
println!(
"DISTANT DATA -- {} {}",
server.port(),
server.to_unprotected_hex_auth_key()
port,
auth_key.unprotected_to_hex_key()
);
// For the child, we want to fully disconnect it from pipes, which we do now

@ -70,9 +70,42 @@ fn should_support_including_a_canonicalized_path(mut action_cmd: Command) {
let file = temp.child("file");
file.touch().unwrap();
let link = temp.child("link");
link.symlink_to_file(file.path()).unwrap();
// distant action metadata --canonicalize {path}
action_cmd
.args(&["metadata", "--canonicalize", file.to_str().unwrap()])
.args(&["metadata", "--canonicalize", link.to_str().unwrap()])
.assert()
.success()
.stdout(regex_pred(&format!(
concat!(
"Canonicalized Path: {:?}\n",
"Type: symlink\n",
"Len: .*\n",
"Readonly: false\n",
"Created: .*\n",
"Last Accessed: .*\n",
"Last Modified: .*\n",
),
file.path().canonicalize().unwrap()
)))
.stderr("");
}
#[rstest]
fn should_support_resolving_file_type_of_symlink(mut action_cmd: Command) {
let temp = assert_fs::TempDir::new().unwrap();
let file = temp.child("file");
file.touch().unwrap();
let link = temp.child("link");
link.symlink_to_file(file.path()).unwrap();
// distant action metadata --canonicalize {path}
action_cmd
.args(&["metadata", "--resolve-file-type", link.to_str().unwrap()])
.assert()
.success()
.stdout(regex_pred(concat!(
@ -115,6 +148,7 @@ fn should_support_json_metadata_for_file(mut action_cmd: Command) {
payload: vec![RequestData::Metadata {
path: file.to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
};
@ -156,6 +190,7 @@ fn should_support_json_metadata_for_directory(mut action_cmd: Command) {
payload: vec![RequestData::Metadata {
path: dir.to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
};
@ -191,12 +226,57 @@ fn should_support_json_metadata_for_including_a_canonicalized_path(mut action_cm
let file = temp.child("file");
file.touch().unwrap();
let link = temp.child("link");
link.symlink_to_file(file.path()).unwrap();
let req = Request {
id: rand::random(),
tenant: random_tenant(),
payload: vec![RequestData::Metadata {
path: file.to_path_buf(),
path: link.to_path_buf(),
canonicalize: true,
resolve_file_type: false,
}],
};
// distant action --format json --interactive
let cmd = action_cmd
.args(&["--format", "json"])
.arg("--interactive")
.write_stdin(format!("{}\n", serde_json::to_string(&req).unwrap()))
.assert()
.success()
.stderr("");
let res: Response = serde_json::from_slice(&cmd.get_output().stdout).unwrap();
match &res.payload[0] {
ResponseData::Metadata {
canonicalized_path: Some(path),
file_type: FileType::Symlink,
readonly: false,
..
} => assert_eq!(path, &file.path().canonicalize().unwrap()),
x => panic!("Unexpected response: {:?}", x),
}
}
#[rstest]
fn should_support_json_metadata_for_resolving_file_type_of_symlink(mut action_cmd: Command) {
let temp = assert_fs::TempDir::new().unwrap();
let file = temp.child("file");
file.touch().unwrap();
let link = temp.child("link");
link.symlink_to_file(file.path()).unwrap();
let req = Request {
id: rand::random(),
tenant: random_tenant(),
payload: vec![RequestData::Metadata {
path: link.to_path_buf(),
canonicalize: true,
resolve_file_type: true,
}],
};
@ -214,9 +294,7 @@ fn should_support_json_metadata_for_including_a_canonicalized_path(mut action_cm
matches!(
res.payload[0],
ResponseData::Metadata {
canonicalized_path: Some(_),
file_type: FileType::File,
readonly: false,
..
},
),
@ -238,6 +316,7 @@ fn should_support_json_output_for_error(mut action_cmd: Command) {
payload: vec![RequestData::Metadata {
path: file.to_path_buf(),
canonicalize: false,
resolve_file_type: false,
}],
};

@ -2,7 +2,7 @@ use crate::cli::utils;
use assert_cmd::Command;
use distant_core::*;
use rstest::*;
use std::{ffi::OsStr, net::SocketAddr, thread};
use std::{ffi::OsStr, net::SocketAddr, sync::Arc, thread};
use tokio::{runtime::Runtime, sync::mpsc};
const LOG_PATH: &'static str = "/tmp/test.distant.server.log";
@ -30,14 +30,18 @@ impl DistantServerCtx {
Ok(rt) => {
rt.block_on(async move {
let logger = utils::init_logging(LOG_PATH);
let server = DistantServer::bind(ip_addr, "0".parse().unwrap(), None, 100)
.await
.unwrap();
let opts = DistantServerOptions {
shutdown_after: None,
max_msg_capacity: 100,
};
let auth_key = Arc::new(SecretKey::default());
let auth_key_hex_str = auth_key.unprotected_to_hex_key();
let (_server, port) =
DistantServer::bind(ip_addr, "0".parse().unwrap(), Some(auth_key), opts)
.await
.unwrap();
started_tx
.send(Ok((server.port(), server.to_unprotected_hex_auth_key())))
.await
.unwrap();
started_tx.send(Ok((port, auth_key_hex_str))).await.unwrap();
let _ = done_rx.recv().await;
logger.flush();

Loading…
Cancel
Save