Refactor codebase to be more testable & add some initial tests

pull/38/head
Chip Senkbeil 3 years ago committed by Chip Senkbeil
parent 1ca3cd7859
commit ba6ebcfcb8
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

125
Cargo.lock generated

@ -20,6 +20,20 @@ dependencies = [
"winapi",
]
[[package]]
name = "assert_cmd"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54f002ce7d0c5e809ebb02be78fd503aeed4a511fd0fcaff6e6914cbdabbfa33"
dependencies = [
"bstr",
"doc-comment",
"predicates",
"predicates-core",
"predicates-tree",
"wait-timeout",
]
[[package]]
name = "atty"
version = "0.2.14"
@ -52,6 +66,17 @@ dependencies = [
"generic-array",
]
[[package]]
name = "bstr"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279"
dependencies = [
"lazy_static",
"memchr",
"regex-automata",
]
[[package]]
name = "bumpalo"
version = "3.7.0"
@ -168,6 +193,12 @@ dependencies = [
"syn",
]
[[package]]
name = "difflib"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8"
[[package]]
name = "digest"
version = "0.9.0"
@ -181,6 +212,7 @@ dependencies = [
name = "distant"
version = "0.13.0"
dependencies = [
"assert_cmd",
"bytes",
"derive_more",
"flexi_logger",
@ -197,6 +229,7 @@ dependencies = [
"serde_json",
"structopt",
"strum",
"tempfile",
"tokio",
"tokio-stream",
"tokio-util",
@ -204,6 +237,12 @@ dependencies = [
"whoami",
]
[[package]]
name = "doc-comment"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]]
name = "ecdsa"
version = "0.12.3"
@ -216,6 +255,12 @@ dependencies = [
"signature",
]
[[package]]
name = "either"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "elliptic-curve"
version = "0.10.5"
@ -448,6 +493,15 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "itertools"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69ddb889f9d0d08a67338271fa9b62996bc788c7796a5c18cf057420aaed5eaf"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "0.4.7"
@ -648,6 +702,33 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "predicates"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c143348f141cc87aab5b950021bac6145d0e5ae754b0591de23244cee42c9308"
dependencies = [
"difflib",
"itertools",
"predicates-core",
]
[[package]]
name = "predicates-core"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57e35a3326b75e49aa85f5dc6ec15b41108cf5aee58eabb1f274dd18b73c2451"
[[package]]
name = "predicates-tree"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7dd0fd014130206c9352efbdc92be592751b2b9274dff685348341082c6ea3d"
dependencies = [
"predicates-core",
"treeline",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@ -762,12 +843,27 @@ dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
[[package]]
name = "regex-syntax"
version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "ryu"
version = "1.0.5"
@ -951,6 +1047,20 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "tempfile"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
dependencies = [
"cfg-if",
"libc",
"rand",
"redox_syscall",
"remove_dir_all",
"winapi",
]
[[package]]
name = "textwrap"
version = "0.11.0"
@ -1047,6 +1157,12 @@ dependencies = [
"tokio",
]
[[package]]
name = "treeline"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7f741b240f1a48843f9b8e0444fb55fb2a4ff67293b50a9179dfd5ea67f8d41"
[[package]]
name = "typenum"
version = "1.13.0"
@ -1083,6 +1199,15 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe"
[[package]]
name = "wait-timeout"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6"
dependencies = [
"libc",
]
[[package]]
name = "walkdir"
version = "2.3.2"

@ -39,3 +39,7 @@ fork = "0.1.18"
lazy_static = "1.4.0"
structopt = "0.3.22"
whoami = "1.1.2"
[dev-dependencies]
assert_cmd = "2.0.0"
tempfile = "3.2.0"

@ -1,38 +1,41 @@
use crate::core::net::TransportError;
use crate::core::{client::RemoteProcessError, net::TransportError};
/// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ExitCode {
/// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI
Usage = 64,
Usage,
/// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data
/// is bad
DataErr = 65,
DataErr,
/// EX_NOINPUT (66) - being used when not getting expected data from launch
NoInput = 66,
NoInput,
/// EX_NOHOST (68) - being used when failed to resolve a host
NoHost = 68,
NoHost,
/// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem
Unavailable = 69,
Unavailable,
/// EX_SOFTWARE (70) - being used for internal errors that can occur like joining a task
Software = 70,
Software,
/// EX_OSERR (71) - being used when fork failed
OsErr = 71,
OsErr,
/// EX_IOERR (74) - being used as catchall for IO errors
IoError = 74,
IoError,
/// EX_TEMPFAIL (75) - being used when we get a timeout
TempFail = 75,
TempFail,
/// EX_PROTOCOL (76) - being used as catchall for transport errors
Protocol = 76,
Protocol,
/// Custom exit code to pass back verbatim
Custom(i32),
}
/// Represents an error that can be converted into an exit code
@ -40,7 +43,19 @@ pub trait ExitCodeError: std::error::Error {
fn to_exit_code(&self) -> ExitCode;
fn to_i32(&self) -> i32 {
self.to_exit_code() as i32
match self.to_exit_code() {
ExitCode::Usage => 64,
ExitCode::DataErr => 65,
ExitCode::NoInput => 66,
ExitCode::NoHost => 68,
ExitCode::Unavailable => 69,
ExitCode::Software => 70,
ExitCode::OsErr => 71,
ExitCode::IoError => 74,
ExitCode::TempFail => 75,
ExitCode::Protocol => 76,
ExitCode::Custom(x) => x,
}
}
}
@ -68,6 +83,19 @@ impl ExitCodeError for TransportError {
}
}
impl ExitCodeError for RemoteProcessError {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::BadResponse => ExitCode::DataErr,
Self::ChannelDead => ExitCode::Unavailable,
Self::Overloaded => ExitCode::Software,
Self::TransportError(x) => x.to_exit_code(),
Self::UnexpectedEof => ExitCode::IoError,
Self::WaitFailed(_) => ExitCode::Software,
}
}
}
impl<T: ExitCodeError + 'static> From<T> for Box<dyn ExitCodeError> {
fn from(x: T) -> Self {
Box::new(x)

@ -0,0 +1,112 @@
use crate::{
cli,
core::{
client::{
RemoteLspStderr, RemoteLspStdin, RemoteLspStdout, RemoteStderr, RemoteStdin,
RemoteStdout,
},
constants::MAX_PIPE_CHUNK_SIZE,
},
};
use std::{
io::{self, Write},
thread,
};
use tokio::task::{JoinError, JoinHandle};
/// Represents a link between a remote process' stdin/stdout/stderr and this process'
/// stdin/stdout/stderr
pub struct RemoteProcessLink {
_stdin_thread: thread::JoinHandle<()>,
stdin_task: JoinHandle<io::Result<()>>,
stdout_task: JoinHandle<io::Result<()>>,
stderr_task: JoinHandle<io::Result<()>>,
}
macro_rules! from_pipes {
($stdin:expr, $stdout:expr, $stderr:expr) => {{
let (stdin_thread, mut stdin_rx) = cli::stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
let stdin_task = tokio::spawn(async move {
loop {
if let Some(input) = stdin_rx.recv().await {
if let Err(x) = $stdin.write(input.as_str()).await {
break Err(x);
}
} else {
break Ok(());
}
}
});
let stdout_task = tokio::spawn(async move {
let handle = io::stdout();
loop {
match $stdout.read().await {
Ok(output) => {
let mut out = handle.lock();
out.write_all(output.as_bytes())?;
out.flush()?;
}
Err(x) => break Err(x),
}
}
});
let stderr_task = tokio::spawn(async move {
let handle = io::stderr();
loop {
match $stderr.read().await {
Ok(output) => {
let mut out = handle.lock();
out.write_all(output.as_bytes())?;
out.flush()?;
}
Err(x) => break Err(x),
}
}
});
RemoteProcessLink {
_stdin_thread: stdin_thread,
stdin_task,
stdout_task,
stderr_task,
}
}};
}
impl RemoteProcessLink {
/// Creates a new process link from the pipes of a remote process
pub fn from_remote_pipes(
mut stdin: RemoteStdin,
mut stdout: RemoteStdout,
mut stderr: RemoteStderr,
) -> Self {
from_pipes!(stdin, stdout, stderr)
}
/// Creates a new process link from the pipes of a remote LSP server process
pub fn from_remote_lsp_pipes(
mut stdin: RemoteLspStdin,
mut stdout: RemoteLspStdout,
mut stderr: RemoteLspStderr,
) -> Self {
from_pipes!(stdin, stdout, stderr)
}
/// Shuts down the link, aborting any running tasks, and swallowing join errors
pub async fn shutdown(self) {
self.abort();
let _ = self.wait().await;
}
/// Waits for the stdin, stdout, and stderr tasks to complete
pub async fn wait(self) -> Result<(), JoinError> {
tokio::try_join!(self.stdin_task, self.stdout_task, self.stderr_task).map(|_| ())
}
/// Aborts the link by aborting tasks processing stdin, stdout, and stderr
pub fn abort(&self) {
self.stdin_task.abort();
self.stdout_task.abort();
self.stderr_task.abort();
}
}

@ -1,10 +1,13 @@
mod buf;
mod exit;
mod link;
mod opt;
mod output;
mod session;
mod stdin;
mod subcommand;
pub use exit::{ExitCode, ExitCodeError};
pub use opt::*;
pub use output::ResponseOut;
pub use session::CliSession;

@ -1,57 +1,84 @@
use crate::{
cli::{buf::StringBuf, Format, ResponseOut},
cli::{buf::StringBuf, stdin, Format, ResponseOut},
core::{
client::Session,
constants::MAX_PIPE_CHUNK_SIZE,
data::{Request, Response},
data::{Request, RequestData, Response},
net::DataStream,
},
};
use log::*;
use std::{
io::{self, BufReader, Read},
sync::Arc,
thread,
};
use tokio::sync::{mpsc, watch};
use std::{io, thread};
use structopt::StructOpt;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// Represents a wrapper around a session that provides CLI functionality such as reading from
/// stdin and piping results back out to stdout
pub struct CliSession<T>
where
T: DataStream,
{
inner: Session<T>,
pub struct CliSession {
stdin_thread: thread::JoinHandle<()>,
req_task: JoinHandle<()>,
res_task: JoinHandle<io::Result<()>>,
}
impl<T> CliSession<T>
where
T: DataStream,
{
pub fn new(inner: Session<T>) -> Self {
Self { inner }
impl CliSession {
pub fn new<T>(tenant: String, mut session: Session<T>, format: Format) -> Self
where
T: DataStream + 'static,
{
let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
let (exit_tx, exit_rx) = mpsc::channel(1);
let stream = session.to_response_broadcast_stream();
let res_task =
tokio::spawn(async move { process_incoming_responses(stream, format, exit_rx).await });
let map_line = move |line: &str| match format {
Format::Json => serde_json::from_str(&line)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x)),
Format::Shell => {
let data = RequestData::from_iter_safe(
std::iter::once("distant")
.chain(line.trim().split(' ').filter(|s| !s.trim().is_empty())),
)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x));
data.map(|x| Request::new(tenant.to_string(), vec![x]))
}
};
let req_task = tokio::spawn(async move {
process_outgoing_requests(session, stdin_rx, exit_tx, format, map_line).await
});
Self {
stdin_thread,
req_task,
res_task,
}
}
}
// TODO TODO TODO:
//
// 1. Change watch to broadcast if going to use in both loops, otherwise just make
// it an mpsc otherwise
// 2. Need to provide outgoing requests function with logic from inner.rs to create a request
// based on the format (json or shell), where json uses serde_json::from_str and shell
// uses Request::new(tenant.as_str(), vec![RequestData::from_iter_safe(...)])
// 3. Need to add a wait method to block on the running tasks
// 4. Need to add an abort method to abort the tasks
// 5. Is there any way to deal with the blocking thread for stdin to kill it? This isn't a big
// deal as the shutdown would only be happening on client termination anyway, but still...
/// Wait for the cli session to terminate
pub async fn wait(self) -> io::Result<()> {
match tokio::try_join!(self.req_task, self.res_task) {
Ok((_, res)) => res,
Err(x) => Err(io::Error::new(io::ErrorKind::BrokenPipe, x)),
}
}
/// Aborts the cli session forcing its task handlers to abort underneath, which means that a
/// call to `wait` will return an error
pub async fn abort(&self) {
self.req_task.abort();
self.res_task.abort();
}
}
/// Helper function that loops, processing incoming responses not tied to a request to be sent out
/// over stdout/stderr
async fn process_incoming_responses(
mut stream: BroadcastStream<Response>,
format: Format,
mut exit: watch::Receiver<bool>,
mut exit: mpsc::Receiver<()>,
) -> io::Result<()> {
loop {
tokio::select! {
@ -62,7 +89,7 @@ async fn process_incoming_responses(
None => return Ok(()),
}
}
_ = exit.changed() => {
_ = exit.recv() => {
return Ok(());
}
}
@ -74,6 +101,7 @@ async fn process_incoming_responses(
async fn process_outgoing_requests<T, F>(
mut session: Session<T>,
mut stdin_rx: mpsc::Receiver<String>,
exit_tx: mpsc::Sender<()>,
format: Format,
map_line: F,
) where
@ -90,9 +118,16 @@ async fn process_outgoing_requests<T, F>(
// For each complete line, parse into a request
if let Some(lines) = lines {
for line in lines.lines() {
for line in lines.lines().map(|line| line.trim()) {
trace!("Processing line: {:?}", line);
if line.trim().is_empty() {
if line.is_empty() {
continue;
} else if line == "exit" {
debug!("Got exit request, so closing cli session");
stdin_rx.close();
if let Err(_) = exit_tx.send(()).await {
error!("Failed to close cli session");
}
continue;
}
@ -114,42 +149,3 @@ async fn process_outgoing_requests<T, F>(
}
}
}
/// Creates a new thread that performs stdin reads in a blocking fashion, returning
/// a handle to the thread and a receiver that will be sent input as it becomes available
fn spawn_stdin_reader() -> (thread::JoinHandle<()>, mpsc::Receiver<String>) {
let (tx, rx) = mpsc::channel(1);
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
// pass the results to a separate async handler to forward to the remote process
let handle = thread::spawn(move || {
let mut stdin = BufReader::new(io::stdin());
// Maximum chunk that we expect to read at any one time
let mut buf = [0; MAX_PIPE_CHUNK_SIZE];
loop {
match stdin.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(n) => {
match String::from_utf8(buf[..n].to_vec()) {
Ok(text) => {
if let Err(x) = tx.blocking_send(text) {
error!(
"Failed to pass along stdin to be sent to remote process: {}",
x
);
}
}
Err(x) => {
error!("Input over stdin is invalid: {}", x);
}
}
thread::yield_now();
}
}
}
});
(handle, rx)
}

@ -0,0 +1,43 @@
use log::error;
use std::{
io::{self, BufReader, Read},
thread,
};
use tokio::sync::mpsc;
/// Creates a new thread that performs stdin reads in a blocking fashion, returning
/// a handle to the thread and a receiver that will be sent input as it becomes available
pub fn spawn_channel(buffer: usize) -> (thread::JoinHandle<()>, mpsc::Receiver<String>) {
let (tx, rx) = mpsc::channel(1);
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
// pass the results to a separate async handler to forward to the remote process
let handle = thread::spawn(move || {
let mut stdin = BufReader::new(io::stdin());
// Maximum chunk that we expect to read at any one time
let mut buf = vec![0; buffer];
loop {
match stdin.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(n) => {
match String::from_utf8(buf[..n].to_vec()) {
Ok(text) => {
if let Err(x) = tx.blocking_send(text) {
error!("Stdin channel closed: {}", x);
break;
}
}
Err(x) => {
error!("Input over stdin is invalid: {}", x);
}
}
thread::yield_now();
}
}
}
});
(handle, rx)
}

@ -0,0 +1,186 @@
use crate::{
cli::{
link::RemoteProcessLink,
opt::{ActionSubcommand, CommonOpt, SessionInput},
CliSession, ExitCode, ExitCodeError, ResponseOut,
},
core::{
client::{
self, LspData, RemoteProcess, RemoteProcessError, Session, SessionInfo, SessionInfoFile,
},
data::{Request, RequestData},
net::{DataStream, TransportError},
},
};
use derive_more::{Display, Error, From};
use tokio::{io, time::Duration};
#[derive(Debug, Display, Error, From)]
pub enum Error {
#[display(fmt = "Process failed with exit code: {}", _0)]
BadProcessExit(#[error(not(source))] i32),
IoError(io::Error),
#[display(fmt = "Non-interactive but no operation supplied")]
MissingOperation,
RemoteProcessError(RemoteProcessError),
TransportError(TransportError),
}
impl ExitCodeError for Error {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::BadProcessExit(x) => ExitCode::Custom(*x),
Self::IoError(x) => x.to_exit_code(),
Self::MissingOperation => ExitCode::Usage,
Self::RemoteProcessError(x) => x.to_exit_code(),
Self::TransportError(x) => x.to_exit_code(),
}
}
}
pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt).await })
}
async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let timeout = opt.to_timeout_duration();
match cmd.session {
SessionInput::Environment => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
timeout,
None,
)
.await
}
SessionInput::File => {
let path = cmd.session_data.session_file.clone();
start(
cmd,
Session::tcp_connect_timeout(
SessionInfoFile::load_from(path).await?.into(),
timeout,
)
.await?,
timeout,
None,
)
.await
}
SessionInput::Pipe => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
timeout,
None,
)
.await
}
SessionInput::Lsp => {
let mut data =
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
let info = data.take_session_info().map_err(io::Error::from)?;
start(
cmd,
Session::tcp_connect_timeout(info, timeout).await?,
timeout,
Some(data),
)
.await
}
#[cfg(unix)]
SessionInput::Socket => {
let path = cmd.session_data.session_socket.clone();
start(
cmd,
Session::unix_connect_timeout(path, None, timeout).await?,
timeout,
None,
)
.await
}
}
}
async fn start<T>(
cmd: ActionSubcommand,
mut session: Session<T>,
timeout: Duration,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
{
// TODO: Because lsp is being handled in a separate action, we should fail if we get
// a session type of lsp for a regular action
match (cmd.interactive, cmd.operation) {
// ProcRun request is specially handled and we ignore interactive as
// the stdin will be used for sending ProcStdin to remote process
(_, Some(RequestData::ProcRun { cmd, args })) => {
let mut proc = RemoteProcess::spawn(client::new_tenant(), session, cmd, args).await?;
// If we also parsed an LSP's initialize request for its session, we want to forward
// it along in the case of a process call
if let Some(data) = lsp_data {
proc.stdin.as_mut().unwrap().write(data.to_string()).await?;
}
// Now, map the remote process' stdin/stdout/stderr to our own process
let link = RemoteProcessLink::from_remote_pipes(
proc.stdin.take().unwrap(),
proc.stdout.take().unwrap(),
proc.stderr.take().unwrap(),
);
let (success, exit_code) = proc.wait().await?;
// Shut down our link
link.shutdown().await;
if !success {
if let Some(code) = exit_code {
return Err(Error::BadProcessExit(code));
} else {
return Err(Error::BadProcessExit(1));
}
}
Ok(())
}
// All other requests without interactive are oneoffs
(false, Some(data)) => {
let res = session
.send_timeout(Request::new(client::new_tenant(), vec![data]), timeout)
.await?;
ResponseOut::new(cmd.format, res)?.print();
Ok(())
}
// Interactive mode will send an optional first request and then continue
// to read stdin to send more
(true, maybe_req) => {
// Send our first request if provided
if let Some(data) = maybe_req {
let res = session
.send_timeout(Request::new(client::new_tenant(), vec![data]), timeout)
.await?;
ResponseOut::new(cmd.format, res)?.print();
}
// Enter into CLI session where we receive requests on stdin and send out
// over stdout/stderr
let cli_session = CliSession::new(client::new_tenant(), session, cmd.format);
cli_session.wait().await?;
Ok(())
}
// Not interactive and no operation given
(false, None) => Err(Error::MissingOperation),
}
}

@ -1,182 +0,0 @@
use crate::{
cli::opt::Format,
core::{
constants::MAX_PIPE_CHUNK_SIZE,
data::{Error, Request, RequestData, Response, ResponseData},
net::{Client, DataStream},
utils::StringBuf,
},
};
use derive_more::IsVariant;
use log::*;
use std::marker::Unpin;
use structopt::StructOpt;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
sync::{
mpsc,
oneshot::{self, error::TryRecvError},
},
};
use tokio_stream::StreamExt;
#[derive(Copy, Clone, PartialEq, Eq, IsVariant)]
pub enum LoopConfig {
Json,
Proc { id: usize },
Shell,
}
impl From<LoopConfig> for Format {
fn from(config: LoopConfig) -> Self {
match config {
LoopConfig::Json => Self::Json,
LoopConfig::Proc { .. } | LoopConfig::Shell => Self::Shell,
}
}
}
/// Starts a new action loop that processes requests and receives responses
///
/// id represents the id of a remote process
pub async fn interactive_loop<T>(
mut client: Client<T>,
tenant: String,
config: LoopConfig,
) -> io::Result<()>
where
T: AsyncRead + AsyncWrite + DataStream + Unpin + 'static,
{
let mut stream = client.to_response_broadcast_stream();
// Create a channel that can report when we should stop the loop based on a received request
let (tx_stop, mut rx_stop) = oneshot::channel::<()>();
// We also want to spawn a task to handle sending stdin to the remote process
let mut rx = spawn_stdin_reader();
tokio::spawn(async move {
let mut buf = StringBuf::new();
while let Some(data) = rx.recv().await {
match config {
// Special exit condition for interactive format
_ if buf.trim() == "exit" => {
if let Err(_) = tx_stop.send(()) {
error!("Failed to close interactive loop!");
}
break;
}
// For json format, all stdin is treated as individual requests
LoopConfig::Json => {
buf.push_str(&data);
let (lines, new_buf) = buf.into_full_lines();
buf = new_buf;
// For each complete line, parse it as json and
if let Some(lines) = lines {
for data in lines.lines() {
debug!("Client sending request: {:?}", data);
let result = serde_json::from_str(&data)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x));
match result {
Ok(req) => match client.send(req).await {
Ok(res) => match format_response(Format::Json, res) {
Ok(out) => out.print(),
Err(x) => error!("Failed to format response: {}", x),
},
Err(x) => {
error!("Failed to send request: {}", x)
}
},
Err(x) => {
error!("Failed to serialize request ('{}'): {}", data, x);
}
}
}
}
}
// For interactive shell format, parse stdin as individual commands
LoopConfig::Shell => {
buf.push_str(&data);
let (lines, new_buf) = buf.into_full_lines();
buf = new_buf;
if let Some(lines) = lines {
for data in lines.lines() {
trace!("Shell processing line: {:?}", data);
if data.trim().is_empty() {
continue;
}
debug!("Client sending command: {:?}", data);
// NOTE: We have to stick something in as the first argument as clap/structopt
// expect the binary name as the first item in the iterator
let result = RequestData::from_iter_safe(
std::iter::once("distant")
.chain(data.trim().split(' ').filter(|s| !s.trim().is_empty())),
);
match result {
Ok(data) => {
match client
.send(Request::new(tenant.as_str(), vec![data]))
.await
{
Ok(res) => match format_response(Format::Shell, res) {
Ok(out) => out.print(),
Err(x) => error!("Failed to format response: {}", x),
},
Err(x) => {
error!("Failed to send request: {}", x)
}
}
}
Err(x) => {
error!("Failed to parse command: {}", x);
}
}
}
}
}
// For non-interactive shell format, all stdin is treated as a proc's stdin
LoopConfig::Proc { id } => {
debug!("Client sending stdin: {:?}", data);
let req =
Request::new(tenant.as_str(), vec![RequestData::ProcStdin { id, data }]);
let result = client.send(req).await;
if let Err(x) = result {
error!("Failed to send stdin to remote process ({}): {}", id, x);
}
}
}
}
});
while let Err(TryRecvError::Empty) = rx_stop.try_recv() {
if let Some(res) = stream.next().await {
let res = res.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))?;
// NOTE: If the loop is for a proxy process, we should assume that the payload
// is all-or-nothing for the done check
let done = config.is_proc() && res.payload.iter().any(|x| x.is_proc_done());
format_response(config.into(), res)?.print();
// If we aren't interactive but are just running a proc and
// we've received the end of the proc, we should exit
if done {
break;
}
// If we have nothing else in our stream, we should also exit
} else {
break;
}
}
Ok(())
}

@ -1,221 +0,0 @@
use crate::{
cli::{
opt::{ActionSubcommand, CommonOpt, Format, SessionInput},
ExitCode, ExitCodeError,
},
core::{
client::{LspData, Session, SessionInfo, SessionInfoFile},
data::{Request, RequestData, ResponseData},
net::{DataStream, TransportError},
},
};
use derive_more::{Display, Error, From};
use log::*;
use tokio::{io, time::Duration};
pub(crate) mod inner;
#[derive(Debug, Display, Error, From)]
pub enum Error {
IoError(io::Error),
TransportError(TransportError),
#[display(fmt = "Non-interactive but no operation supplied")]
MissingOperation,
}
impl ExitCodeError for Error {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::IoError(x) => x.to_exit_code(),
Self::TransportError(x) => x.to_exit_code(),
Self::MissingOperation => ExitCode::Usage,
}
}
}
pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt).await })
}
async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let timeout = opt.to_timeout_duration();
match cmd.session {
SessionInput::Environment => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
timeout,
None,
)
.await
}
SessionInput::File => {
let path = cmd.session_data.session_file.clone();
start(
cmd,
Session::tcp_connect_timeout(
SessionInfoFile::load_from(path).await?.into(),
timeout,
)
.await?,
timeout,
None,
)
.await
}
SessionInput::Pipe => {
start(
cmd,
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
timeout,
None,
)
.await
}
SessionInput::Lsp => {
let mut data =
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
let info = data.take_session_info().map_err(io::Error::from)?;
start(
cmd,
Session::tcp_connect_timeout(info, timeout).await?,
timeout,
Some(data),
)
.await
}
#[cfg(unix)]
SessionInput::Socket => {
let path = cmd.session_data.session_socket.clone();
start(
cmd,
Session::unix_connect_timeout(path, None, timeout).await?,
timeout,
None,
)
.await
}
}
}
async fn start<T>(
cmd: ActionSubcommand,
mut session: Session<T>,
timeout: Duration,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
{
// TODO: Because lsp is being handled in a separate action, we should fail if we get
// a session type of lsp for a regular action
match (cmd.interactive, cmd.operation) {
// ProcRun request is specially handled and we ignore interactive as
// the stdin will be used for sending ProcStdin to remote process
(_, Some(RequestData::ProcRun { cmd, args })) => {}
// All other requests without interactive are oneoffs
(false, Some(req)) => {
let res = session.send_timeout(req, timeout).await?;
}
// Interactive mode will send an optional first request and then continue
// to read stdin to send more
(true, maybe_req) => {}
// Not interactive and no operation given
(false, None) => Err(Error::MissingOperation),
}
// 1. Determine what type of engagement we're doing
// a. Oneoff connection, request, response
// b. ProcRun where we take over stdin, stdout, stderr to provide a remote
// process experience
// c. Lsp where we do the ProcRun stuff, but translate stdin before sending and
// stdout before outputting
// d. Interactive program
//
// 2. If we have a queued up operation, we need to perform it
// a. For oneoff, this is the request of the oneoff
// b. For Procrun, this is the request that starts the process
// c. For Lsp, this is the request that starts the process
// d. For interactive, this is an optional first request
//
// 3. If we are using LSP session mode, then we want to send the
// ProcStdin request after our optional queued up operation
// a. For oneoff, this doesn't make sense and we should fail
// b. For ProcRun, we do this after the ProcStart
// c. For Lsp, we do this after the ProcStart
// d. For interactive, this doesn't make sense as we only support
// JSON and shell command input, not LSP input, so this would
// fail and we should fail early
//
// ** LSP would be its own action, which means we want to abstract the logic that feeds
// into this start method such that it can also be used with lsp action
// Make up a tenant name
let tenant = utils::new_tenant();
// Special conditions for continuing to process responses
let mut is_proc_req = false;
let mut proc_id = 0;
if let Some(req) = cmd
.operation
.map(|payload| Request::new(tenant.as_str(), vec![payload]))
{
// NOTE: We know that there is a single payload entry, so it's all-or-nothing
is_proc_req = req.payload.iter().any(|x| x.is_proc_run());
debug!("Client sending request: {:?}", req);
let res = session.send_timeout(req, timeout).await?;
// Store the spawned process id for using in sending stdin (if we spawned a proc)
// NOTE: We can assume that there is a single payload entry in response to our single
// entry in our request
if let Some(ResponseData::ProcStart { id }) = res.payload.first() {
proc_id = *id;
}
inner::format_response(cmd.format, res)?.print();
// If we also parsed an LSP's initialize request for its session, we want to forward
// it along in the case of a process call
//
// TODO: Do we need to do this somewhere else to apply to all possible ways an LSP
// could be started?
if let Some(data) = lsp_data {
session
.fire_timeout(
Request::new(
tenant.as_str(),
vec![RequestData::ProcStdin {
id: proc_id,
data: data.to_string(),
}],
),
timeout,
)
.await?;
}
}
// If we are executing a process, we want to continue interacting via stdin and receiving
// results via stdout/stderr
//
// If we are interactive, we want to continue looping regardless
if is_proc_req || cmd.interactive {
let config = match cmd.format {
Format::Json => inner::LoopConfig::Json,
Format::Shell if cmd.interactive => inner::LoopConfig::Shell,
Format::Shell => inner::LoopConfig::Proc { id: proc_id },
};
inner::interactive_loop(client, tenant, config).await?;
}
Ok(())
}

@ -1,27 +1,18 @@
use crate::{
cli::{
opt::{CommonOpt, Format, LaunchSubcommand, SessionOutput},
ExitCode, ExitCodeError,
CliSession, ExitCode, ExitCodeError,
},
core::{
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{Client, Transport, TransportReadHalf, TransportWriteHalf},
session::{Session, SessionFile},
utils,
client::{self, Session, SessionInfo, SessionInfoFile},
server::RelayServer,
},
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
process::Command,
runtime::{Handle, Runtime},
sync::{broadcast, mpsc, oneshot, Mutex},
time::Duration,
};
use std::{path::Path, string::FromUtf8Error};
use tokio::{io, process::Command, runtime::Runtime, time::Duration};
#[derive(Debug, Display, Error, From)]
pub enum Error {
@ -44,12 +35,6 @@ impl ExitCodeError for Error {
}
}
/// Represents state associated with a connection
#[derive(Default)]
struct ConnState {
processes: Vec<usize>,
}
pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = Runtime::new()?;
let session_output = cmd.session;
@ -68,7 +53,7 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
match session_output {
SessionOutput::File => {
debug!("Outputting session to {:?}", session_file);
rt.block_on(async { SessionFile::new(session_file, session).save().await })?
rt.block_on(async { SessionInfoFile::new(session_file, session).save().await })?
}
SessionOutput::Keep => {
debug!("Entering interactive loop over stdin");
@ -139,54 +124,27 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
Ok(())
}
async fn keep_loop(session: Session, format: Format, duration: Duration) -> io::Result<()> {
use crate::cli::subcommand::action::inner;
match Client::tcp_connect_timeout(session, duration).await {
Ok(client) => {
let config = match format {
Format::Json => inner::LoopConfig::Json,
Format::Shell => inner::LoopConfig::Shell,
};
inner::interactive_loop(client, utils::new_tenant(), config).await
async fn keep_loop(info: SessionInfo, format: Format, duration: Duration) -> io::Result<()> {
match Session::tcp_connect_timeout(info, duration).await {
Ok(session) => {
let cli_session = CliSession::new(client::new_tenant(), session, format);
cli_session.wait().await
}
Err(x) => Err(x),
}
}
#[cfg(unix)]
async fn socket_loop(
socket_path: impl AsRef<Path>,
session: Session,
info: SessionInfo,
duration: Duration,
fail_if_socket_exists: bool,
shutdown_after: Option<Duration>,
) -> io::Result<()> {
// We need to form a connection with the actual server to forward requests
// and responses between connections
debug!("Connecting to {} {}", session.host, session.port);
let mut client = Client::tcp_connect_timeout(session, duration).await?;
// Get a copy of our client's broadcaster so we can have each connection
// subscribe to it for new messages filtered by tenant
debug!("Acquiring client broadcaster");
let broadcaster = client.to_response_broadcaster();
// Spawn task to send to the server requests from connections
debug!("Spawning request forwarding task");
let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
tokio::spawn(async move {
while let Some(req) = req_rx.recv().await {
debug!(
"Forwarding request of type{} {} to server",
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = client.fire_timeout(req, duration).await {
error!("Client failed to send request: {:?}", x);
break;
}
}
});
debug!("Connecting to {} {}", info.host, info.port);
let session = Session::tcp_connect_timeout(info, duration).await?;
// Remove the socket file if it already exists
if !fail_if_socket_exists && socket_path.as_ref().exists() {
@ -199,205 +157,17 @@ async fn socket_loop(
debug!("Binding to unix socket: {:?}", socket_path.as_ref());
let listener = tokio::net::UnixListener::bind(socket_path)?;
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
loop {
tokio::select! {
result = listener.accept() => {match result {
Ok((conn, _)) => {
// Create a unique id to associate with the connection since its address
// is not guaranteed to have an identifiable string
let conn_id: usize = rand::random();
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(conn, None).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {:?}> Failed handshake: {}", conn_id, x);
continue;
}
};
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
// based on the first
let (tenant_tx, tenant_rx) = oneshot::channel();
// Create a state we use to keep track of connection-specific data
debug!("<Client @ {}> Initializing internal state", conn_id);
let state = Arc::new(Mutex::new(ConnState::default()));
// Spawn task to continually receive responses from the client that
// may or may not be relevant to the connection, which will filter
// by tenant and then along any response that matches
let res_rx = broadcaster.subscribe();
let state_2 = Arc::clone(&state);
tokio::spawn(async move {
handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await;
});
// Spawn task to continually read requests from connection and forward
// them along to be sent via the client
let req_tx = req_tx.clone();
let ct_2 = Arc::clone(&ct);
tokio::spawn(async move {
ct_2.lock().await.increment();
handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await;
ct_2.lock().await.decrement();
debug!("<Client @ {:?}> Disconnected", conn_id);
});
}
Err(x) => {
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
}
}
Ok(())
}
/// Conn::Request -> Client::Fire
async fn handle_conn_incoming<T>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut reader: TransportReadHalf<T>,
tenant_tx: oneshot::Sender<String>,
req_tx: mpsc::Sender<Request>,
) where
T: AsyncRead + Unpin,
{
macro_rules! process_req {
($on_success:expr; $done:expr) => {
match reader.receive::<Request>().await {
Ok(Some(req)) => {
$on_success(&req);
if let Err(x) = req_tx.send(req).await {
error!(
"Failed to pass along request received on unix socket: {:?}",
x
);
$done;
}
}
Ok(None) => $done,
Err(x) => {
error!("Failed to receive request from unix stream: {:?}", x);
$done;
}
}
};
}
let mut tenant = None;
// NOTE: Have to acquire our first request outside our loop since the oneshot
// sender of the tenant's name is consuming
process_req!(
|req: &Request| {
tenant = Some(req.tenant.clone());
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
error!("Failed to send along acquired tenant name: {:?}", x);
return;
}
};
return
);
// Loop and process all additional requests
loop {
process_req!(|_| {}; break);
}
// At this point, we have processed at least one request successfully
// and should have the tenant populated. If we had a failure at the
// beginning, we exit the function early via return.
let tenant = tenant.unwrap();
// Perform cleanup if done by sending a request to kill each running process
// debug!("Cleaning conn {} :: killing process {}", conn_id, id);
if let Err(x) = req_tx
.send(Request::new(
tenant.clone(),
state
.lock()
.await
.processes
.iter()
.map(|id| RequestData::ProcKill { id: *id })
.collect(),
))
let server = RelayServer::initialize(session, listener, shutdown_after).await?;
server
.wait()
.await
{
error!("<Client @ {}> Failed to send kill signals: {}", conn_id, x);
}
}
async fn handle_conn_outgoing<T>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut writer: TransportWriteHalf<T>,
tenant_rx: oneshot::Receiver<String>,
mut res_rx: broadcast::Receiver<Response>,
) where
T: AsyncWrite + Unpin,
{
// We wait for the tenant to be identified by the first request
// before processing responses to be sent back; this is easier
// to implement and yields the same result as we would be dropping
// all responses before we know the tenant
if let Ok(tenant) = tenant_rx.await {
debug!("Associated tenant {} with conn {}", tenant, conn_id);
loop {
match res_rx.recv().await {
// Forward along responses that are for our connection
Ok(res) if res.tenant == tenant => {
debug!(
"Conn {} being sent response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
);
// If a new process was started, we want to capture the id and
// associate it with the connection
let ids = res.payload.iter().filter_map(|x| match x {
ResponseData::ProcStart { id } => Some(*id),
_ => None,
});
for id in ids {
debug!("Tracking proc {} for conn {}", id, conn_id);
state.lock().await.processes.push(id);
}
if let Err(x) = writer.send(res).await {
error!("Failed to send response through unix connection: {}", x);
break;
}
}
// Skip responses that are not for our connection
Ok(_) => {}
Err(x) => {
error!(
"Conn {} failed to receive broadcast response: {}",
conn_id, x
);
break;
}
}
}
}
.map_err(|x| io::Error::new(io::ErrorKind::Other, x))
}
/// Spawns a remote server that listens for requests
///
/// Returns the session associated with the server
async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<Session, Error> {
async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<SessionInfo, Error> {
let distant_command = format!(
"{} listen --daemon --host {} {}",
cmd.distant,
@ -417,6 +187,7 @@ async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<S
distant_command.trim().to_string()
} else {
// TODO: Do we need to try to escape single quotes here because of extra_server_args?
// TODO: Replace this with the ssh2 library shell exec once we integrate that
format!("echo {} | $SHELL -l", distant_command.trim())
},
);
@ -437,11 +208,11 @@ async fn spawn_remote_server(cmd: LaunchSubcommand, _opt: CommonOpt) -> Result<S
// Parse our output for the specific session line
// NOTE: The host provided on this line isn't valid, so we fill it in with our actual host
let out = String::from_utf8(out.stdout)?.trim().to_string();
let mut session = out
let mut info = out
.lines()
.find_map(|line| line.parse::<Session>().ok())
.find_map(|line| line.parse::<SessionInfo>().ok())
.ok_or(Error::MissingSessionData)?;
session.host = cmd.host;
info.host = cmd.host;
Ok(session)
Ok(info)
}

@ -68,7 +68,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
let server = DistantServer::bind(
addr,
cmd.port,
cmd.to_shutdown_after_duration(),
shutdown_after,
cmd.max_msg_capacity as usize,
)
.await?;

@ -297,7 +297,7 @@ impl FromStr for LspHeader {
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LspContent(Map<String, Value>);
fn for_each_mut_string<F1, F2>(value: &mut Value, check: F1, mutate: F2)
fn for_each_mut_string<F1, F2>(value: &mut Value, check: &F1, mutate: &mut F2)
where
F1: Fn(&String) -> bool,
F2: FnMut(&mut String),
@ -309,12 +309,15 @@ where
.for_each(|v| for_each_mut_string(v, check, mutate));
// Mutate keys if necessary
for key in obj.keys() {
if check(key) {
if let Some((key, value)) = obj.remove_entry(key) {
mutate(&mut key);
obj.insert(key, value);
}
let keys: Vec<String> = obj
.keys()
.filter(|k| check(k))
.map(ToString::to_string)
.collect();
for key in keys {
if let Some((mut key, value)) = obj.remove_entry(&key) {
mutate(&mut key);
obj.insert(key, value);
}
}
}
@ -328,7 +331,7 @@ where
fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) {
let check = |s: &String| s.starts_with(old);
let mutate = |s: &mut String| {
let mut mutate = |s: &mut String| {
if let Some(pos) = s.find(old) {
s.replace_range(pos..old.len(), new);
}
@ -336,15 +339,18 @@ fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) {
// Mutate values
obj.values_mut()
.for_each(|v| for_each_mut_string(v, check, mutate));
.for_each(|v| for_each_mut_string(v, &check, &mut mutate));
// Mutate keys if necessary
for key in obj.keys() {
if check(key) {
if let Some((key, value)) = obj.remove_entry(key) {
mutate(&mut key);
obj.insert(key, value);
}
let keys: Vec<String> = obj
.keys()
.filter(|k| check(k))
.map(ToString::to_string)
.collect();
for key in keys {
if let Some((mut key, value)) = obj.remove_entry(&key) {
mutate(&mut key);
obj.insert(key, value);
}
}
}
@ -528,7 +534,11 @@ mod tests {
fn data_from_buf_reader_should_fail_if_reach_eof_before_received_full_data() {
// No line termination
let err = LspData::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err();
assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err);
assert!(
matches!(err, LspDataParseError::BadHeaderTermination),
"{:?}",
err
);
// Header doesn't finish
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
@ -1100,7 +1110,7 @@ mod tests {
#[test]
fn content_convert_distant_scheme_to_local_should_convert_keys_and_values() {
let content = LspContent(make_obj!({
let mut content = LspContent(make_obj!({
"distant://key1": "file://value1",
"file://key2": "distant://value2",
"key3": ["file://value3", "distant://value4"],

@ -21,6 +21,7 @@ impl RemoteLspProcess {
/// Spawns the specified process on the remote machine using the given session, treating
/// the process like an LSP server
pub async fn spawn<T>(
tenant: String,
session: Session<T>,
cmd: String,
args: Vec<String>,
@ -28,7 +29,7 @@ impl RemoteLspProcess {
where
T: DataStream + 'static,
{
let mut inner = RemoteProcess::spawn(session, cmd, args).await?;
let mut inner = RemoteProcess::spawn(tenant, session, cmd, args).await?;
let stdin = inner.stdin.take().map(RemoteLspStdin::new);
let stdout = inner.stdout.take().map(RemoteLspStdout::new);
let stderr = inner.stderr.take().map(RemoteLspStderr::new);

@ -3,14 +3,7 @@ mod process;
mod session;
mod utils;
// TODO: Make wrappers around a connection to facilitate the types
// of engagements
//
// 1. Command -> Single request/response through a future
// 2. Proxy -> Does proc-run and waits until proc-done received,
// exposing a sender for stdin and receivers for stdout/stderr,
// and supporting a future await for completion with exit code
// 3.
pub use lsp::*;
pub use process::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
pub use session::*;
pub(crate) use utils::new_tenant;

@ -1,5 +1,5 @@
use crate::core::{
client::{utils, Session},
client::Session,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{DataStream, TransportError},
@ -63,6 +63,7 @@ pub struct RemoteProcess {
impl RemoteProcess {
/// Spawns the specified process on the remote machine using the given session
pub async fn spawn<T>(
tenant: String,
mut session: Session<T>,
cmd: String,
args: Vec<String>,
@ -70,8 +71,6 @@ impl RemoteProcess {
where
T: DataStream + 'static,
{
let tenant = utils::new_tenant();
// Submit our run request and wait for a response
let res = session
.send(Request::new(
@ -127,7 +126,10 @@ impl RemoteProcess {
/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
self.res_task.await?
match tokio::try_join!(self.req_task, self.res_task) {
Ok((_, res)) => res,
Err(x) => Err(RemoteProcessError::from(x)),
}
}
/// Aborts the process by forcing its response task to shutdown, which means that a call

@ -50,7 +50,7 @@ where
impl Session<InmemoryStream> {
/// Creates a session around an inmemory transport
pub async fn from_inmemory_transport(transport: Transport<InmemoryStream>) -> io::Result<Self> {
Self::inner_connect(transport).await
Self::initialize(transport).await
}
}
@ -67,7 +67,7 @@ impl Session<TcpStream> {
.map(|x| x.to_string())
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
Self::initialize(transport).await
}
/// Connect to a remote TCP server, timing out after duration has passed
@ -93,7 +93,7 @@ impl Session<tokio::net::UnixStream> {
.map(|x| format!("{:?}", x))
.unwrap_or_else(|_| String::from("???"))
);
Self::inner_connect(transport).await
Self::initialize(transport).await
}
/// Connect to a proxy unix socket, timing out after duration has passed
@ -112,8 +112,8 @@ impl<T> Session<T>
where
T: DataStream,
{
/// Establishes a connection using the provided transport
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
/// Initializes a session using the provided transport
pub async fn initialize(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast, init_broadcast_receiver) =
@ -243,14 +243,13 @@ mod tests {
use crate::core::{
constants::test::TENANT,
data::{RequestData, ResponseData},
net::test::make_transport_pair,
};
use std::time::Duration;
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = make_transport_pair();
let mut session = Session::inner_connect(t1).await.unwrap();
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
@ -270,8 +269,8 @@ mod tests {
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = make_transport_pair();
let mut session = Session::inner_connect(t1).await.unwrap();
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.send_timeout(req, Duration::from_millis(30)).await {
@ -285,8 +284,8 @@ mod tests {
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = make_transport_pair();
let mut session = Session::inner_connect(t1).await.unwrap();
let (t1, mut t2) = Transport::make_pair();
let mut session = Session::initialize(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match session.fire(req).await {

@ -0,0 +1,48 @@
use super::DataStream;
use std::{future::Future, pin::Pin};
use tokio::{
io,
net::{TcpListener, TcpStream},
};
/// Represents a type that has a listen interface
pub trait Listener: Send + Sync {
type Conn: DataStream;
/// Async function that accepts a new connection, returning `Ok(Self::Conn)`
/// upon receiving the next connection
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
where
Self: Sync + 'a;
}
impl Listener for TcpListener {
type Conn = TcpStream;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
where
Self: Sync + 'a,
{
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
_self.accept().await.map(|(stream, _)| stream)
}
Box::pin(accept(self))
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Conn = tokio::net::UnixStream;
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
where
Self: Sync + 'a,
{
async fn accept(_self: &tokio::net::UnixListener) -> io::Result<tokio::net::UnixStream> {
_self.accept().await.map(|(stream, _)| stream)
}
Box::pin(accept(self))
}
}

@ -1,4 +1,7 @@
mod listener;
mod transport;
pub use listener::Listener;
pub use transport::*;
// Re-export commonly-used orion structs

@ -1,6 +1,7 @@
use super::DataStream;
use super::{DataStream, SecretKey, Transport};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::{
@ -118,10 +119,146 @@ impl DataStream for InmemoryStream {
type Write = InmemoryStreamWriteHalf;
fn to_connection_tag(&self) -> String {
String::from("test-stream")
String::from("inmemory-stream")
}
fn into_split(self) -> (Self::Read, Self::Write) {
(self.incoming, self.outgoing)
}
}
impl Transport<InmemoryStream> {
/// Produces a pair of inmemory transports that are connected to each other with matching
/// auth and encryption keys
///
/// Sets the buffer for message passing for each underlying stream to the given buffer size
pub fn pair(buffer: usize) -> (Transport<InmemoryStream>, Transport<InmemoryStream>) {
let auth_key = Arc::new(SecretKey::default());
let crypt_key = Arc::new(SecretKey::default());
let (a, b) = InmemoryStream::pair(buffer);
let a = Transport::new(a, Some(Arc::clone(&auth_key)), Arc::clone(&crypt_key));
let b = Transport::new(b, Some(auth_key), crypt_key);
(a, b)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[test]
fn to_connection_tag_should_be_hardcoded_string() {
let (_, _, stream) = InmemoryStream::make(1);
assert_eq!(stream.to_connection_tag(), "inmemory-stream");
}
#[tokio::test]
async fn make_should_return_sender_that_sends_data_to_stream() {
let (tx, _, mut stream) = InmemoryStream::make(3);
tx.send(b"test msg 1".to_vec()).await.unwrap();
tx.send(b"test msg 2".to_vec()).await.unwrap();
tx.send(b"test msg 3".to_vec()).await.unwrap();
// Should get data matching a singular message
let mut buf = [0; 256];
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 1");
// Next call would get the second message
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 2");
// When the last of the senders is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(tx);
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 3");
let len = stream.read(&mut buf).await.unwrap();
assert_eq!(len, 0, "Unexpectedly got more data");
}
#[tokio::test]
async fn make_should_return_receiver_that_receives_data_from_stream() {
let (_, mut rx, mut stream) = InmemoryStream::make(3);
stream.write_all(b"test msg 1").await.unwrap();
stream.write_all(b"test msg 2").await.unwrap();
stream.write_all(b"test msg 3").await.unwrap();
// Should get data matching a singular message
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
// Next call would get the second message
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
// When the stream is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(stream);
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
}
#[tokio::test]
async fn into_split_should_provide_a_read_half_that_receives_from_sender() {
let (tx, _, stream) = InmemoryStream::make(3);
let (mut read_half, _) = stream.into_split();
tx.send(b"test msg 1".to_vec()).await.unwrap();
tx.send(b"test msg 2".to_vec()).await.unwrap();
tx.send(b"test msg 3".to_vec()).await.unwrap();
// Should get data matching a singular message
let mut buf = [0; 256];
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 1");
// Next call would get the second message
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 2");
// When the last of the senders is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(tx);
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(&buf[..len], b"test msg 3");
let len = read_half.read(&mut buf).await.unwrap();
assert_eq!(len, 0, "Unexpectedly got more data");
}
#[tokio::test]
async fn into_split_should_provide_a_write_half_that_sends_to_receiver() {
let (_, mut rx, stream) = InmemoryStream::make(3);
let (_, mut write_half) = stream.into_split();
write_half.write_all(b"test msg 1").await.unwrap();
write_half.write_all(b"test msg 2").await.unwrap();
write_half.write_all(b"test msg 3").await.unwrap();
// Should get data matching a singular message
assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));
// Next call would get the second message
assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));
// When the stream is dropped, we should still get
// the rest of the data that was sent first before getting
// an indicator that there is no more data
drop(write_half);
assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));
assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
}
}

@ -346,31 +346,25 @@ where
/// Test utilities
#[cfg(test)]
pub mod test {
use super::*;
use crate::core::constants::test::BUFFER_SIZE;
use crate::net::InmemoryStream;
use orion::aead::SecretKey;
impl Transport<InmemoryStream> {
/// Makes a connected pair of transports with matching crypt keys and using the provided
/// auth keys
pub fn make_transport_pair_with_auth_keys(
pub fn make_pair_with_auth_keys(
ak1: Option<Arc<SecretKey>>,
ak2: Option<Arc<SecretKey>>,
) -> (Transport<InmemoryStream>, Transport<InmemoryStream>) {
let crypt_key = Arc::new(SecretKey::default());
let (a, b) = InmemoryStream::pair(BUFFER_SIZE);
let (a, b) = InmemoryStream::pair(crate::core::constants::test::BUFFER_SIZE);
let a = Transport::new(a, ak1, Arc::clone(&crypt_key));
let b = Transport::new(b, ak2, crypt_key);
(a, b)
}
/// Makes a connected pair of transports with matching auth and crypt keys
pub fn make_transport_pair() -> (Transport<InmemoryStream>, Transport<InmemoryStream>) {
let auth_key = Arc::new(SecretKey::default());
make_transport_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key))
/// using test buffer size
pub fn make_pair() -> (Transport<InmemoryStream>, Transport<InmemoryStream>) {
Self::pair(crate::core::constants::test::BUFFER_SIZE)
}
}
@ -380,8 +374,6 @@ mod tests {
use crate::core::constants::test::BUFFER_SIZE;
use std::io;
use test::make_transport_pair_with_auth_keys;
#[tokio::test]
async fn transport_from_handshake_should_fail_if_connection_reached_eof() {
// Cause nothing left incoming to stream by _
@ -462,7 +454,7 @@ mod tests {
#[tokio::test]
async fn transport_should_be_able_to_send_encrypted_data_to_other_side_to_decrypt() {
// Make two transports with no auth keys
let (mut src, mut dst) = make_transport_pair_with_auth_keys(None, None);
let (mut src, mut dst) = Transport::make_pair_with_auth_keys(None, None);
src.send("some data").await.expect("Failed to send data");
let data = dst
@ -480,7 +472,7 @@ mod tests {
// Make two transports with same auth keys
let (mut src, mut dst) =
make_transport_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key));
Transport::make_pair_with_auth_keys(Some(Arc::clone(&auth_key)), Some(auth_key));
src.send("some data").await.expect("Failed to send data");
let data = dst
@ -495,7 +487,7 @@ mod tests {
#[tokio::test]
async fn transport_receive_should_fail_if_auth_key_differs_from_other_end() {
// Make two transports with different auth keys
let (mut src, mut dst) = make_transport_pair_with_auth_keys(
let (mut src, mut dst) = Transport::make_pair_with_auth_keys(
Some(Arc::new(SecretKey::default())),
Some(Arc::new(SecretKey::default())),
);
@ -511,7 +503,7 @@ mod tests {
async fn transport_receive_should_fail_if_has_auth_key_while_sender_did_not_use_one() {
// Make two transports with different auth keys
let (mut src, mut dst) =
make_transport_pair_with_auth_keys(None, Some(Arc::new(SecretKey::default())));
Transport::make_pair_with_auth_keys(None, Some(Arc::new(SecretKey::default())));
src.send("some data").await.expect("Failed to send data");
@ -529,7 +521,7 @@ mod tests {
async fn transport_receive_should_fail_if_has_no_auth_key_while_sender_used_one() {
// Make two transports with different auth keys
let (mut src, mut dst) =
make_transport_pair_with_auth_keys(Some(Arc::new(SecretKey::default())), None);
Transport::make_pair_with_auth_keys(Some(Arc::new(SecretKey::default())), None);
src.send("some data").await.expect("Failed to send data");
match dst.receive::<String>().await {

@ -3,14 +3,13 @@ use crate::core::{
data::{
self, DirEntry, FileType, Request, RequestData, Response, ResponseData, RunningProcess,
},
server::state::{Process, State},
server::distant::state::{Process, State},
};
use derive_more::{Display, Error, From};
use futures::future;
use log::*;
use std::{
env,
net::SocketAddr,
path::{Path, PathBuf},
process::Stdio,
sync::Arc,
@ -24,7 +23,7 @@ use tokio::{
use walkdir::WalkDir;
pub type Reply = mpsc::Sender<Response>;
type HState = Arc<Mutex<State<SocketAddr>>>;
type HState = Arc<Mutex<State>>;
#[derive(Debug, Display, Error, From)]
pub enum ServerError {
@ -43,14 +42,14 @@ impl From<ServerError> for ResponseData {
/// Processes the provided request, sending replies using the given sender
pub(super) async fn process(
addr: SocketAddr,
conn_id: usize,
state: HState,
req: Request,
tx: Reply,
) -> Result<(), mpsc::error::SendError<Response>> {
async fn inner(
tenant: Arc<String>,
addr: SocketAddr,
conn_id: usize,
state: HState,
data: RequestData,
tx: Reply,
@ -76,7 +75,7 @@ pub(super) async fn process(
RequestData::Exists { path } => exists(path).await,
RequestData::Metadata { path, canonicalize } => metadata(path, canonicalize).await,
RequestData::ProcRun { cmd, args } => {
proc_run(tenant.to_string(), addr, state, tx, cmd, args).await
proc_run(tenant.to_string(), conn_id, state, tx, cmd, args).await
}
RequestData::ProcKill { id } => proc_kill(state, id).await,
RequestData::ProcStdin { id, data } => proc_stdin(state, id, data).await,
@ -94,7 +93,7 @@ pub(super) async fn process(
let state_2 = Arc::clone(&state);
let tx_2 = tx.clone();
payload_tasks.push(tokio::spawn(async move {
match inner(tenant_2, addr, state_2, data, tx_2).await {
match inner(tenant_2, conn_id, state_2, data, tx_2).await {
Ok(data) => data,
Err(x) => ResponseData::from(x),
}
@ -114,8 +113,8 @@ pub(super) async fn process(
let res = Response::new(req.tenant, Some(req.id), payload);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -358,7 +357,7 @@ async fn metadata(path: PathBuf, canonicalize: bool) -> Result<ResponseData, Ser
async fn proc_run(
tenant: String,
addr: SocketAddr,
conn_id: usize,
state: HState,
tx: Reply,
cmd: String,
@ -389,8 +388,8 @@ async fn proc_run(
vec![ResponseData::ProcStdout { id, data }],
);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -430,8 +429,8 @@ async fn proc_run(
vec![ResponseData::ProcStderr { id, data }],
);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -491,8 +490,8 @@ async fn proc_run(
vec![ResponseData::ProcDone { id, success, code }]
);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -503,8 +502,8 @@ async fn proc_run(
Err(x) => {
let res = Response::new(tenant.as_str(), None, vec![ResponseData::from(x)]);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -533,8 +532,8 @@ async fn proc_run(
id, success: false, code: None
}]);
debug!(
"<Client @ {}> Sending response of type{} {}",
addr,
"<Conn @ {}> Sending response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string()
);
@ -556,7 +555,7 @@ async fn proc_run(
stdin_tx,
kill_tx,
};
state.lock().await.push_process(addr, process);
state.lock().await.push_process(conn_id, process);
Ok(ResponseData::ProcStart { id })
}
@ -609,3 +608,676 @@ async fn system_info() -> Result<ResponseData, ServerError> {
main_separator: std::path::MAIN_SEPARATOR,
})
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::{NamedTempFile, TempDir};
fn setup(
buffer: usize,
) -> (
usize,
Arc<Mutex<State>>,
mpsc::Sender<Response>,
mpsc::Receiver<Response>,
) {
let (tx, rx) = mpsc::channel(buffer);
(
rand::random(),
Arc::new(Mutex::new(State::default())),
tx,
rx,
)
}
/// Create a temporary path that does not exist
fn temppath() -> PathBuf {
// Deleted when dropped
NamedTempFile::new().unwrap().into_temp_path().to_path_buf()
}
#[tokio::test]
async fn file_read_should_send_error_if_fails_to_read_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a file and then delete it, keeping just its path
let path = temppath();
let req = Request::new("test-tenant", vec![RequestData::FileRead { path }]);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
}
#[tokio::test]
async fn file_read_should_send_blob_with_file_contents() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary file and fill it with some contents
let mut file = NamedTempFile::new().unwrap();
file.write_all(b"some file contents").unwrap();
let req = Request::new(
"test-tenant",
vec![RequestData::FileRead {
path: file.path().to_path_buf(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::Blob { data } => assert_eq!(data, b"some file contents"),
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn file_read_text_should_send_error_if_fails_to_read_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a file and then delete it, keeping just its path
let path = temppath();
let req = Request::new(
"test-tenant",
vec![RequestData::FileReadText { path: path }],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
}
#[tokio::test]
async fn file_read_text_should_send_text_with_file_contents() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary file and fill it with some contents
let mut file = NamedTempFile::new().unwrap();
file.write_all(b"some file contents").unwrap();
let req = Request::new(
"test-tenant",
vec![RequestData::FileReadText {
path: file.path().to_path_buf(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::Text { data } => assert_eq!(data, "some file contents"),
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn file_write_should_send_error_if_fails_to_write_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary path and add to it to ensure that there are
// extra components that don't exist to cause writing to fail
let path = temppath().join("some_file");
let req = Request::new(
"test-tenant",
vec![RequestData::FileWrite {
path: path.clone(),
data: b"some text".to_vec(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we didn't actually create the file
assert!(!path.exists(), "File created unexpectedly");
}
#[tokio::test]
async fn file_write_should_send_ok_when_successful() {
let (conn_id, state, tx, mut rx) = setup(1);
// Path should point to a file that does not exist, but all
// other components leading up to it do
let path = temppath();
let req = Request::new(
"test-tenant",
vec![RequestData::FileWrite {
path: path.clone(),
data: b"some text".to_vec(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we actually did create the file
// with the associated contents
assert!(path.exists(), "File not actually created");
assert_eq!(tokio::fs::read_to_string(path).await.unwrap(), "some text");
}
#[tokio::test]
async fn file_write_text_should_send_error_if_fails_to_write_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary path and add to it to ensure that there are
// extra components that don't exist to cause writing to fail
let path = temppath().join("some_file");
let req = Request::new(
"test-tenant",
vec![RequestData::FileWriteText {
path: path.clone(),
text: String::from("some text"),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we didn't actually create the file
assert!(!path.exists(), "File created unexpectedly");
}
#[tokio::test]
async fn file_write_text_should_send_ok_when_successful() {
let (conn_id, state, tx, mut rx) = setup(1);
// Path should point to a file that does not exist, but all
// other components leading up to it do
let path = temppath();
let req = Request::new(
"test-tenant",
vec![RequestData::FileWriteText {
path: path.clone(),
text: String::from("some text"),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we actually did create the file
// with the associated contents
assert!(path.exists(), "File not actually created");
assert_eq!(tokio::fs::read_to_string(path).await.unwrap(), "some text");
}
#[tokio::test]
async fn file_append_should_send_error_if_fails_to_create_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary path and add to it to ensure that there are
// extra components that don't exist to cause writing to fail
let path = temppath().join("some_file");
let req = Request::new(
"test-tenant",
vec![RequestData::FileAppend {
path: path.to_path_buf(),
data: b"some extra contents".to_vec(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we didn't actually create the file
assert!(!path.exists(), "File created unexpectedly");
}
#[tokio::test]
async fn file_append_should_send_ok_when_successful() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary file and fill it with some contents
let mut file = NamedTempFile::new().unwrap();
file.write_all(b"some file contents").unwrap();
let req = Request::new(
"test-tenant",
vec![RequestData::FileAppend {
path: file.path().to_path_buf(),
data: b"some extra contents".to_vec(),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we actually did append to the file
assert_eq!(
tokio::fs::read_to_string(file.path()).await.unwrap(),
"some file contentssome extra contents"
);
}
#[tokio::test]
async fn file_append_text_should_send_error_if_fails_to_create_file() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary path and add to it to ensure that there are
// extra components that don't exist to cause writing to fail
let path = temppath().join("some_file");
let req = Request::new(
"test-tenant",
vec![RequestData::FileAppendText {
path: path.to_path_buf(),
text: String::from("some extra contents"),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we didn't actually create the file
assert!(!path.exists(), "File created unexpectedly");
}
#[tokio::test]
async fn file_append_text_should_send_ok_when_successful() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create a temporary file and fill it with some contents
let mut file = NamedTempFile::new().unwrap();
file.write_all(b"some file contents").unwrap();
let req = Request::new(
"test-tenant",
vec![RequestData::FileAppendText {
path: file.path().to_path_buf(),
text: String::from("some extra contents"),
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that we actually did append to the file
assert_eq!(
tokio::fs::read_to_string(file.path()).await.unwrap(),
"some file contentssome extra contents"
);
}
#[tokio::test]
async fn dir_read_should_send_error_if_directory_does_not_exist() {
let (conn_id, state, tx, mut rx) = setup(1);
let path = temppath();
let req = Request::new(
"test-tenant",
vec![RequestData::DirRead {
path,
depth: 0,
absolute: false,
canonicalize: false,
include_root: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
}
// /root/
// /root/file1
// /root/sub1/
// /root/sub1/file2
async fn setup_dir() -> TempDir {
let root_dir = TempDir::new().unwrap();
let file1 = root_dir.path().join("file1");
let sub1 = root_dir.path().join("sub1");
let file2 = sub1.join("file2");
tokio::fs::write(file1.as_path(), "").await.unwrap();
tokio::fs::create_dir(sub1.as_path()).await.unwrap();
tokio::fs::write(file2.as_path(), "").await.unwrap();
root_dir
}
#[tokio::test]
async fn dir_read_should_support_depth_limits() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create directory with some nested items
let root_dir = setup_dir().await;
let req = Request::new(
"test-tenant",
vec![RequestData::DirRead {
path: root_dir.path().to_path_buf(),
depth: 1,
absolute: false,
canonicalize: false,
include_root: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::DirEntries { entries, .. } => {
assert_eq!(entries.len(), 2, "Wrong number of entries found");
assert_eq!(entries[0].file_type, FileType::File);
assert_eq!(entries[0].path, Path::new("file1"));
assert_eq!(entries[0].depth, 1);
assert_eq!(entries[1].file_type, FileType::Dir);
assert_eq!(entries[1].path, Path::new("sub1"));
assert_eq!(entries[1].depth, 1);
}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn dir_read_should_support_unlimited_depth_using_zero() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create directory with some nested items
let root_dir = setup_dir().await;
let req = Request::new(
"test-tenant",
vec![RequestData::DirRead {
path: root_dir.path().to_path_buf(),
depth: 0,
absolute: false,
canonicalize: false,
include_root: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::DirEntries { entries, .. } => {
assert_eq!(entries.len(), 3, "Wrong number of entries found");
assert_eq!(entries[0].file_type, FileType::File);
assert_eq!(entries[0].path, Path::new("file1"));
assert_eq!(entries[0].depth, 1);
assert_eq!(entries[1].file_type, FileType::Dir);
assert_eq!(entries[1].path, Path::new("sub1"));
assert_eq!(entries[1].depth, 1);
assert_eq!(entries[2].file_type, FileType::File);
assert_eq!(entries[2].path, Path::new("sub1").join("file2"));
assert_eq!(entries[2].depth, 2);
}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn dir_read_should_support_including_directory_in_returned_entries() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create directory with some nested items
let root_dir = setup_dir().await;
let req = Request::new(
"test-tenant",
vec![RequestData::DirRead {
path: root_dir.path().to_path_buf(),
depth: 1,
absolute: false,
canonicalize: false,
include_root: true,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::DirEntries { entries, .. } => {
assert_eq!(entries.len(), 3, "Wrong number of entries found");
// NOTE: Root entry is always absolute, resolved path
assert_eq!(entries[0].file_type, FileType::Dir);
assert_eq!(entries[0].path, root_dir.path().canonicalize().unwrap());
assert_eq!(entries[0].depth, 0);
assert_eq!(entries[1].file_type, FileType::File);
assert_eq!(entries[1].path, Path::new("file1"));
assert_eq!(entries[1].depth, 1);
assert_eq!(entries[2].file_type, FileType::Dir);
assert_eq!(entries[2].path, Path::new("sub1"));
assert_eq!(entries[2].depth, 1);
}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
async fn dir_read_should_support_returning_absolute_paths() {
let (conn_id, state, tx, mut rx) = setup(1);
// Create directory with some nested items
let root_dir = setup_dir().await;
let req = Request::new(
"test-tenant",
vec![RequestData::DirRead {
path: root_dir.path().to_path_buf(),
depth: 1,
absolute: true,
canonicalize: false,
include_root: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
match &res.payload[0] {
ResponseData::DirEntries { entries, .. } => {
assert_eq!(entries.len(), 2, "Wrong number of entries found");
let root_path = root_dir.path().canonicalize().unwrap();
assert_eq!(entries[0].file_type, FileType::File);
assert_eq!(entries[0].path, root_path.join("file1"));
assert_eq!(entries[0].depth, 1);
assert_eq!(entries[1].file_type, FileType::Dir);
assert_eq!(entries[1].path, root_path.join("sub1"));
assert_eq!(entries[1].depth, 1);
}
x => panic!("Unexpected response: {:?}", x),
}
}
#[tokio::test]
#[ignore]
async fn dir_read_should_support_returning_canonicalized_paths() {
todo!("Figure out best way to support symlink tests");
}
#[tokio::test]
async fn dir_create_should_send_error_if_fails() {
let (conn_id, state, tx, mut rx) = setup(1);
// Make a path that has multiple non-existent components
// so the creation will fail
let root_dir = setup_dir().await;
let path = root_dir.path().join("nested").join("new-dir");
let req = Request::new(
"test-tenant",
vec![RequestData::DirCreate {
path: path.to_path_buf(),
all: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Error(_)),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that the directory was not actually created
assert!(!path.exists(), "Path unexpectedly exists");
}
#[tokio::test]
async fn dir_create_should_send_ok_when_successful() {
let (conn_id, state, tx, mut rx) = setup(1);
let root_dir = setup_dir().await;
let path = root_dir.path().join("new-dir");
let req = Request::new(
"test-tenant",
vec![RequestData::DirCreate {
path: path.to_path_buf(),
all: false,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that the directory was actually created
assert!(path.exists(), "Directory not created");
}
#[tokio::test]
async fn dir_create_should_support_creating_multiple_dir_components() {
let (conn_id, state, tx, mut rx) = setup(1);
let root_dir = setup_dir().await;
let path = root_dir.path().join("nested").join("new-dir");
let req = Request::new(
"test-tenant",
vec![RequestData::DirCreate {
path: path.to_path_buf(),
all: true,
}],
);
process(conn_id, state, req, tx).await.unwrap();
let res = rx.recv().await.unwrap();
assert_eq!(res.payload.len(), 1, "Wrong payload size");
assert!(
matches!(res.payload[0], ResponseData::Ok),
"Unexpected response: {:?}",
res.payload[0]
);
// Also verify that the directory was actually created
assert!(path.exists(), "Directory not created");
}
}

@ -1,39 +1,37 @@
mod handler;
mod port;
mod state;
mod utils;
pub use port::{PortRange, PortRangeParseError};
use state::State;
use crate::core::{
data::{Request, Response},
net::{SecretKey, Transport, TransportReadHalf, TransportWriteHalf},
server::{
utils::{ConnTracker, ShutdownTask},
PortRange,
},
};
use futures::future::OptionFuture;
use log::*;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use std::{net::IpAddr, sync::Arc};
use tokio::{
io,
net::{tcp, TcpListener, TcpStream},
runtime::Handle,
sync::{mpsc, Mutex, Notify},
sync::{mpsc, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
/// Represents a server that listens for requests, processes them, and sends responses
pub struct Server {
pub struct DistantServer {
port: u16,
state: Arc<Mutex<State<SocketAddr>>>,
auth_key: Arc<SecretKey>,
notify: Arc<Notify>,
conn_task: JoinHandle<()>,
}
impl Server {
impl DistantServer {
/// Bind to an IP address and port from the given range, taking an optional shutdown duration
/// that will shutdown the server if there is no active connection after duration
pub async fn bind(
addr: IpAddr,
port: PortRange,
@ -47,21 +45,19 @@ impl Server {
debug!("Bound to port: {}", port);
// Build our state for the server
let state: Arc<Mutex<State<SocketAddr>>> = Arc::new(Mutex::new(State::default()));
let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::default()));
let auth_key = Arc::new(SecretKey::default());
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
// Spawn our connection task
let state_2 = Arc::clone(&state);
let auth_key_2 = Arc::clone(&auth_key);
let notify_2 = Arc::clone(&notify);
let conn_task = tokio::spawn(async move {
connection_loop(
listener,
state_2,
state,
auth_key_2,
ct,
notify_2,
tracker,
shutdown,
max_msg_capacity,
)
.await
@ -69,9 +65,7 @@ impl Server {
Ok(Self {
port,
state,
auth_key,
notify,
conn_task,
})
}
@ -90,33 +84,32 @@ impl Server {
pub async fn wait(self) -> Result<(), JoinError> {
self.conn_task.await
}
/// Shutdown the server
pub fn shutdown(&self) {
self.notify.notify_one()
}
}
async fn connection_loop(
listener: TcpListener,
state: Arc<Mutex<State<SocketAddr>>>,
state: Arc<Mutex<State>>,
auth_key: Arc<SecretKey>,
tracker: Arc<Mutex<utils::ConnTracker>>,
notify: Arc<Notify>,
tracker: Option<Arc<Mutex<ConnTracker>>>,
shutdown: OptionFuture<ShutdownTask>,
max_msg_capacity: usize,
) {
loop {
tokio::select! {
result = listener.accept() => {match result {
let inner = async move {
loop {
match listener.accept().await {
Ok((conn, addr)) => {
let conn_id = rand::random();
debug!("<Conn @ {}> Established against {}", conn_id, addr);
if let Err(x) = on_new_conn(
conn,
addr,
conn_id,
Arc::clone(&state),
Arc::clone(&auth_key),
Arc::clone(&tracker),
max_msg_capacity
).await {
tracker.as_ref().map(Arc::clone),
max_msg_capacity,
)
.await
{
error!("<Conn @ {}> Failed handshake: {}", addr, x);
}
}
@ -124,12 +117,15 @@ async fn connection_loop(
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
}
};
tokio::select! {
_ = inner => {}
_ = shutdown => {
warn!("Reached shutdown timeout, so terminating");
}
}
}
@ -137,10 +133,10 @@ async fn connection_loop(
/// input and output, returning join handles for the input and output tasks respectively
async fn on_new_conn(
conn: TcpStream,
addr: SocketAddr,
state: Arc<Mutex<State<SocketAddr>>>,
conn_id: usize,
state: Arc<Mutex<State>>,
auth_key: Arc<SecretKey>,
tracker: Arc<Mutex<utils::ConnTracker>>,
tracker: Option<Arc<Mutex<ConnTracker>>>,
max_msg_capacity: usize,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> {
// Establish a proper connection via a handshake,
@ -151,23 +147,26 @@ async fn on_new_conn(
// and output concurrently
let (t_read, t_write) = transport.into_split();
let (tx, rx) = mpsc::channel(max_msg_capacity);
let ct_2 = Arc::clone(&tracker);
// Spawn a new task that loops to handle requests from the client
let req_task = tokio::spawn({
let f = request_loop(addr, Arc::clone(&state), t_read, tx);
let f = request_loop(conn_id, Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
ct_2.lock().await.increment();
if let Some(ct) = tracker.as_ref() {
ct.lock().await.increment();
}
f.await;
state.lock().await.cleanup_client(addr).await;
ct_2.lock().await.decrement();
state.lock().await.cleanup_connection(conn_id).await;
if let Some(ct) = tracker.as_ref() {
ct.lock().await.decrement();
}
}
});
// Spawn a new task that loops to handle responses to the client
let res_task = tokio::spawn(async move { response_loop(addr, t_write, rx).await });
let res_task = tokio::spawn(async move { response_loop(conn_id, t_write, rx).await });
Ok((req_task, res_task))
}
@ -175,8 +174,8 @@ async fn on_new_conn(
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
async fn request_loop(
addr: SocketAddr,
state: Arc<Mutex<State<SocketAddr>>>,
conn_id: usize,
state: Arc<Mutex<State>>,
mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
tx: mpsc::Sender<Response>,
) {
@ -185,22 +184,23 @@ async fn request_loop(
Ok(Some(req)) => {
debug!(
"<Conn @ {}> Received request of type{} {}",
addr,
conn_id,
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await {
error!("<Conn @ {}> {}", addr, x);
if let Err(x) = handler::process(conn_id, Arc::clone(&state), req, tx.clone()).await
{
error!("<Conn @ {}> {}", conn_id, x);
break;
}
}
Ok(None) => {
info!("<Conn @ {}> Closed connection", addr);
info!("<Conn @ {}> Closed connection", conn_id);
break;
}
Err(x) => {
error!("<Conn @ {}> {}", addr, x);
error!("<Conn @ {}> {}", conn_id, x);
break;
}
}
@ -209,13 +209,13 @@ async fn request_loop(
/// Repeatedly sends responses out over the wire
async fn response_loop(
addr: SocketAddr,
conn_id: usize,
mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>,
mut rx: mpsc::Receiver<Response>,
) {
while let Some(res) = rx.recv().await {
if let Err(x) = transport.send(res).await {
error!("<Conn @ {}> {}", addr, x);
error!("<Conn @ {}> {}", conn_id, x);
break;
}
}

@ -1,66 +0,0 @@
use derive_more::{Display, Error};
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
};
/// Represents some range of ports
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
fmt = "{}{}",
start,
"end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()"
)]
pub struct PortRange {
pub start: u16,
pub end: Option<u16>,
}
impl PortRange {
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
let mut socket_addrs = Vec::new();
let addr = addr.into();
for port in self.start..=self.end.unwrap_or(self.start) {
socket_addrs.push(SocketAddr::from((addr, port)));
}
socket_addrs
}
}
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum PortRangeParseError {
InvalidPort,
MissingPort,
}
impl FromStr for PortRange {
type Err = PortRangeParseError;
/// Parses PORT into single range or PORT1:PORTN into full range
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut tokens = s.trim().split(':');
let start = tokens
.next()
.ok_or(PortRangeParseError::MissingPort)?
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?;
let end = if let Some(token) = tokens.next() {
Some(
token
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?,
)
} else {
None
};
if tokens.next().is_some() {
return Err(PortRangeParseError::InvalidPort);
}
Ok(Self { start, end })
}
}

@ -1,46 +1,41 @@
use log::*;
use std::{collections::HashMap, fmt::Debug, hash::Hash};
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};
/// Holds state related to multiple clients managed by a server
pub struct State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{
#[derive(Default)]
pub struct State {
/// Map of all processes running on the server
pub processes: HashMap<usize, Process>,
/// List of processes that will be killed when a client drops
client_processes: HashMap<ClientId, Vec<usize>>,
client_processes: HashMap<usize, Vec<usize>>,
}
impl<ClientId> State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{
impl State {
/// Pushes a new process associated with a client
pub fn push_process(&mut self, client_id: ClientId, process: Process) {
pub fn push_process(&mut self, conn_id: usize, process: Process) {
self.client_processes
.entry(client_id)
.entry(conn_id)
.or_insert(Vec::new())
.push(process.id);
self.processes.insert(process.id, process);
}
/// Cleans up state associated with a particular client
pub async fn cleanup_client(&mut self, client_id: ClientId) {
debug!("<Client @ {:?}> Cleaning up state", client_id);
if let Some(ids) = self.client_processes.remove(&client_id) {
/// Cleans up state associated with a particular connection
pub async fn cleanup_connection(&mut self, conn_id: usize) {
debug!("<Conn @ {:?}> Cleaning up state", conn_id);
if let Some(ids) = self.client_processes.remove(&conn_id) {
for id in ids {
if let Some(process) = self.processes.remove(&id) {
trace!(
"<Client @ {:?}> Requesting proc {} be killed",
client_id,
"<Conn @ {:?}> Requesting proc {} be killed",
conn_id,
process.id
);
if let Err(_) = process.kill_tx.send(()) {
error!(
"Client {} failed to send process {} kill signal",
"Conn {} failed to send process {} kill signal",
id, process.id
);
}
@ -50,18 +45,6 @@ where
}
}
impl<ClientId> Default for State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{
fn default() -> Self {
Self {
processes: HashMap::new(),
client_processes: HashMap::new(),
}
}
}
/// Represents an actively-running process
pub struct Process {
/// Id of the process

@ -1,101 +0,0 @@
use log::*;
use std::{sync::Arc, time::Duration};
use tokio::{
runtime::Handle,
sync::{Mutex, Notify},
time::{self, Instant},
};
pub struct ConnTracker {
time: Instant,
cnt: usize,
}
impl ConnTracker {
pub fn new() -> Self {
Self {
time: Instant::now(),
cnt: 0,
}
}
pub fn increment(&mut self) {
self.time = Instant::now();
self.cnt += 1;
}
pub fn decrement(&mut self) {
if self.cnt > 0 {
self.time = Instant::now();
self.cnt -= 1;
}
}
pub fn time_and_cnt(&self) -> (Instant, usize) {
(self.time, self.cnt)
}
pub fn has_exceeded_timeout(&self, duration: Duration) -> bool {
self.cnt == 0 && self.time.elapsed() > duration
}
}
/// Spawns a new task that continues to monitor the time since a
/// connection on the server existed, shutting down the runtime
/// if the time is exceeded
pub fn new_shutdown_task(
handle: Handle,
duration: Option<Duration>,
) -> (Arc<Mutex<ConnTracker>>, Arc<Notify>) {
let ct = Arc::new(Mutex::new(ConnTracker::new()));
let notify = Arc::new(Notify::new());
let ct_2 = Arc::clone(&ct);
let notify_2 = Arc::clone(&notify);
if let Some(duration) = duration {
handle.spawn(async move {
loop {
// Get the time since the last connection joined/left
let (base_time, cnt) = ct_2.lock().await.time_and_cnt();
// If we have no connections left, we want to wait
// until the remaining period has passed and then
// verify that we still have no connections
if cnt == 0 {
// Get the time we should wait based on when the last connection
// was dropped; this closes the gap in the case where we start
// sometime later than exactly duration since the last check
let next_time = base_time + duration;
let wait_duration = next_time
.checked_duration_since(Instant::now())
.unwrap_or_default()
+ Duration::from_millis(1);
// Wait until we've reached our desired duration since the
// last connection was dropped
time::sleep(wait_duration).await;
// If we do have a connection at this point, don't exit
if !ct_2.lock().await.has_exceeded_timeout(duration) {
continue;
}
// Otherwise, we now should exit, which we do by reporting
debug!(
"Shutdown time of {}s has been reached!",
duration.as_secs_f32()
);
notify_2.notify_one();
break;
}
// Otherwise, we just wait the full duration as worst case
// we'll have waited just about the time desired if right
// after waiting starts the last connection is closed
time::sleep(duration).await;
}
});
}
(ct, notify)
}

@ -1,2 +1,8 @@
mod distant;
pub use distant::{PortRange, PortRangeParseError, Server as DistantServer};
mod port;
mod relay;
mod utils;
pub use self::distant::DistantServer;
pub use port::PortRange;
pub use relay::RelayServer;

@ -0,0 +1,180 @@
use derive_more::Display;
use std::{
net::{IpAddr, SocketAddr},
ops::RangeInclusive,
str::FromStr,
};
/// Represents some range of ports
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
fmt = "{}{}",
start,
"end.as_ref().map(|end| format!(\":{}\", end)).unwrap_or_default()"
)]
pub struct PortRange {
pub start: u16,
pub end: Option<u16>,
}
impl PortRange {
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
let mut socket_addrs = Vec::new();
let addr = addr.into();
for port in self {
socket_addrs.push(SocketAddr::from((addr, port)));
}
socket_addrs
}
}
impl From<RangeInclusive<u16>> for PortRange {
fn from(r: RangeInclusive<u16>) -> Self {
let (start, end) = r.into_inner();
Self {
start,
end: Some(end),
}
}
}
impl<'a> IntoIterator for &'a PortRange {
type Item = u16;
type IntoIter = RangeInclusive<u16>;
fn into_iter(self) -> Self::IntoIter {
self.start..=self.end.unwrap_or(self.start)
}
}
impl IntoIterator for PortRange {
type Item = u16;
type IntoIter = RangeInclusive<u16>;
fn into_iter(self) -> Self::IntoIter {
self.start..=self.end.unwrap_or(self.start)
}
}
impl FromStr for PortRange {
type Err = std::num::ParseIntError;
/// Parses PORT into single range or PORT1:PORTN into full range
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.split_once(':') {
Some((start, end)) => Ok(Self {
start: start.parse()?,
end: Some(end.parse()?),
}),
None => Ok(Self {
start: s.parse()?,
end: None,
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_should_properly_reflect_port_range() {
let p = PortRange {
start: 100,
end: None,
};
assert_eq!(p.to_string(), "100");
let p = PortRange {
start: 100,
end: Some(200),
};
assert_eq!(p.to_string(), "100:200");
}
#[test]
fn from_range_inclusive_should_map_to_port_range() {
let p = PortRange::from(100..=200);
assert_eq!(p.start, 100);
assert_eq!(p.end, Some(200));
}
#[test]
fn into_iterator_should_support_port_range() {
let p = PortRange {
start: 1,
end: None,
};
assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1]);
assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1]);
let p = PortRange {
start: 1,
end: Some(3),
};
assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
}
#[test]
fn make_socket_addrs_should_produce_a_socket_addr_per_port() {
let ip_addr = "127.0.0.1".parse::<IpAddr>().unwrap();
let p = PortRange {
start: 1,
end: None,
};
assert_eq!(
p.make_socket_addrs(ip_addr),
vec![SocketAddr::new(ip_addr, 1)]
);
let p = PortRange {
start: 1,
end: Some(3),
};
assert_eq!(
p.make_socket_addrs(ip_addr),
vec![
SocketAddr::new(ip_addr, 1),
SocketAddr::new(ip_addr, 2),
SocketAddr::new(ip_addr, 3),
]
);
}
#[test]
fn parse_should_fail_if_not_starting_with_number() {
assert!("100a".parse::<PortRange>().is_err());
}
#[test]
fn parse_should_fail_if_provided_end_port_that_is_not_a_number() {
assert!("100:200a".parse::<PortRange>().is_err());
}
#[test]
fn parse_should_be_able_to_properly_read_in_port_range() {
let p: PortRange = "100".parse().unwrap();
assert_eq!(
p,
PortRange {
start: 100,
end: None
}
);
let p: PortRange = "100:200".parse().unwrap();
assert_eq!(
p,
PortRange {
start: 100,
end: Some(200)
}
);
}
}

@ -0,0 +1,336 @@
use crate::core::{
client::Session,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{DataStream, Listener, Transport, TransportReadHalf, TransportWriteHalf},
server::utils::{ConnTracker, ShutdownTask},
};
use log::*;
use std::{collections::HashMap, marker::Unpin, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
sync::{broadcast, mpsc, oneshot, Mutex},
task::{JoinError, JoinHandle},
time::Duration,
};
/// Represents a server that relays requests & responses between connections and the
/// actual server
pub struct RelayServer {
forward_task: JoinHandle<()>,
accept_task: JoinHandle<()>,
conns: Arc<Mutex<HashMap<usize, Conn>>>,
}
impl RelayServer {
pub async fn initialize<T1, T2, L>(
mut session: Session<T1>,
listener: L,
shutdown_after: Option<Duration>,
) -> io::Result<Self>
where
T1: DataStream + 'static,
T2: DataStream + Send + 'static,
L: Listener<Conn = T2> + 'static,
{
// Get a copy of our session's broadcaster so we can have each connection
// subscribe to it for new messages filtered by tenant
debug!("Acquiring session broadcaster");
let broadcaster = session.to_response_broadcaster();
// Spawn task to send to the server requests from connections
debug!("Spawning request forwarding task");
let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let forward_task = tokio::spawn(async move {
while let Some(req) = req_rx.recv().await {
debug!(
"Forwarding request of type{} {} to server",
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = session.fire(req).await {
error!("Session failed to send request: {:?}", x);
break;
}
}
});
let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
let conns = Arc::new(Mutex::new(HashMap::new()));
let conns_2 = Arc::clone(&conns);
let accept_task = tokio::spawn(async move {
let inner = async move {
loop {
match listener.accept().await {
Ok(stream) => {
let result = Conn::initialize(
stream,
req_tx.clone(),
broadcaster.clone(),
tracker.as_ref().map(Arc::clone),
)
.await;
match result {
Ok(conn) => conns_2.lock().await.insert(conn.id(), conn),
Err(x) => {
error!("Failed to initialize connection: {}", x);
continue;
}
};
}
Err(x) => {
debug!("Listener has closed: {}", x);
break;
}
}
}
};
tokio::select! {
_ = inner => {}
_ = shutdown => {
warn!("Reached shutdown timeout, so terminating");
}
}
});
Ok(Self {
forward_task,
accept_task,
conns,
})
}
pub async fn wait(self) -> Result<(), JoinError> {
match tokio::try_join!(self.forward_task, self.accept_task) {
Ok(_) => Ok(()),
Err(x) => Err(x),
}
}
pub async fn abort(&self) {
self.forward_task.abort();
self.accept_task.abort();
self.conns
.lock()
.await
.values()
.for_each(|conn| conn.abort());
}
}
struct Conn {
id: usize,
req_task: JoinHandle<()>,
res_task: JoinHandle<()>,
}
/// Represents state associated with a connection
#[derive(Default)]
struct ConnState {
processes: Vec<usize>,
}
impl Conn {
pub async fn initialize<T>(
stream: T,
req_tx: mpsc::Sender<Request>,
res_broadcaster: broadcast::Sender<Response>,
ct: Option<Arc<Mutex<ConnTracker>>>,
) -> io::Result<Self>
where
T: DataStream + 'static,
{
// Create a unique id to associate with the connection since its address
// is not guaranteed to have an identifiable string
let id: usize = rand::random();
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = Transport::from_handshake(stream, None).await.map_err(|x| {
error!("<Conn @ {}> Failed handshake: {}", id, x);
io::Error::new(io::ErrorKind::Other, x)
})?;
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
// based on the first
let (tenant_tx, tenant_rx) = oneshot::channel();
// Create a state we use to keep track of connection-specific data
debug!("<Conn @ {}> Initializing internal state", id);
let state = Arc::new(Mutex::new(ConnState::default()));
// Spawn task to continually receive responses from the session that
// may or may not be relevant to the connection, which will filter
// by tenant and then along any response that matches
let res_rx = res_broadcaster.subscribe();
let state_2 = Arc::clone(&state);
let res_task = tokio::spawn(async move {
handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await;
});
// Spawn task to continually read requests from connection and forward
// them along to be sent via the session
let req_tx = req_tx.clone();
let req_task = tokio::spawn(async move {
if let Some(ct) = ct.as_ref() {
ct.lock().await.increment();
}
handle_conn_incoming(id, state, t_read, tenant_tx, req_tx).await;
if let Some(ct) = ct.as_ref() {
ct.lock().await.decrement();
}
debug!("<Conn @ {}> Disconnected", id);
});
Ok(Self {
id,
req_task,
res_task,
})
}
/// Id associated with the connection
pub fn id(&self) -> usize {
self.id
}
/// Aborts the connection from the server side
pub fn abort(&self) {
self.req_task.abort();
self.res_task.abort();
}
}
/// Conn::Request -> Session::Fire
async fn handle_conn_incoming<T>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut reader: TransportReadHalf<T>,
tenant_tx: oneshot::Sender<String>,
req_tx: mpsc::Sender<Request>,
) where
T: AsyncRead + Unpin,
{
macro_rules! process_req {
($on_success:expr; $done:expr) => {
match reader.receive::<Request>().await {
Ok(Some(req)) => {
$on_success(&req);
if let Err(x) = req_tx.send(req).await {
error!(
"Failed to pass along request received on unix socket: {:?}",
x
);
$done;
}
}
Ok(None) => $done,
Err(x) => {
error!("Failed to receive request from unix stream: {:?}", x);
$done;
}
}
};
}
let mut tenant = None;
// NOTE: Have to acquire our first request outside our loop since the oneshot
// sender of the tenant's name is consuming
process_req!(
|req: &Request| {
tenant = Some(req.tenant.clone());
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
error!("Failed to send along acquired tenant name: {:?}", x);
return;
}
};
return
);
// Loop and process all additional requests
loop {
process_req!(|_| {}; break);
}
// At this point, we have processed at least one request successfully
// and should have the tenant populated. If we had a failure at the
// beginning, we exit the function early via return.
let tenant = tenant.unwrap();
// Perform cleanup if done by sending a request to kill each running process
// debug!("Cleaning conn {} :: killing process {}", conn_id, id);
if let Err(x) = req_tx
.send(Request::new(
tenant.clone(),
state
.lock()
.await
.processes
.iter()
.map(|id| RequestData::ProcKill { id: *id })
.collect(),
))
.await
{
error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
}
}
async fn handle_conn_outgoing<T>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut writer: TransportWriteHalf<T>,
tenant_rx: oneshot::Receiver<String>,
mut res_rx: broadcast::Receiver<Response>,
) where
T: AsyncWrite + Unpin,
{
// We wait for the tenant to be identified by the first request
// before processing responses to be sent back; this is easier
// to implement and yields the same result as we would be dropping
// all responses before we know the tenant
if let Ok(tenant) = tenant_rx.await {
debug!("Associated tenant {} with conn {}", tenant, conn_id);
loop {
match res_rx.recv().await {
// Forward along responses that are for our connection
Ok(res) if res.tenant == tenant => {
debug!(
"Conn {} being sent response of type{} {}",
conn_id,
if res.payload.len() > 1 { "s" } else { "" },
res.to_payload_type_string(),
);
// If a new process was started, we want to capture the id and
// associate it with the connection
let ids = res.payload.iter().filter_map(|x| match x {
ResponseData::ProcStart { id } => Some(*id),
_ => None,
});
for id in ids {
debug!("Tracking proc {} for conn {}", id, conn_id);
state.lock().await.processes.push(id);
}
if let Err(x) = writer.send(res).await {
error!("Failed to send response through unix connection: {}", x);
break;
}
}
// Skip responses that are not for our connection
Ok(_) => {}
Err(x) => {
error!(
"Conn {} failed to receive broadcast response: {}",
conn_id, x
);
break;
}
}
}
}
}

@ -0,0 +1,291 @@
use futures::future::OptionFuture;
use log::*;
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::{
sync::Mutex,
task::{JoinError, JoinHandle},
time::{self, Instant},
};
/// Task to keep track of a possible server shutdown based on connections
pub struct ShutdownTask {
task: JoinHandle<()>,
tracker: Arc<Mutex<ConnTracker>>,
}
impl Future for ShutdownTask {
type Output = Result<(), JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.task).poll(cx)
}
}
impl ShutdownTask {
/// Given an optional timeout, will either create the shutdown task or not,
/// returning an optional future for the completion of the shutdown task
/// alongside an optional connection tracker
pub fn maybe_initialize(
duration: Option<Duration>,
) -> (OptionFuture<ShutdownTask>, Option<Arc<Mutex<ConnTracker>>>) {
match duration {
Some(duration) => {
let task = Self::initialize(duration);
let tracker = task.tracker();
let task: OptionFuture<_> = Some(task).into();
(task, Some(tracker))
}
None => (None.into(), None),
}
}
/// Spawns a new task that continues to monitor the time since a
/// connection on the server existed, reporting a shutdown to all listeners
/// once the timeout is exceeded
pub fn initialize(duration: Duration) -> Self {
let tracker = Arc::new(Mutex::new(ConnTracker::new()));
let tracker_2 = Arc::clone(&tracker);
let task = tokio::spawn(async move {
loop {
// Get the time since the last connection joined/left
let (base_time, cnt) = tracker_2.lock().await.time_and_cnt();
// If we have no connections left, we want to wait
// until the remaining period has passed and then
// verify that we still have no connections
if cnt == 0 {
// Get the time we should wait based on when the last connection
// was dropped; this closes the gap in the case where we start
// sometime later than exactly duration since the last check
let next_time = base_time + duration;
let wait_duration = next_time
.checked_duration_since(Instant::now())
.unwrap_or_default()
+ Duration::from_millis(1);
// Wait until we've reached our desired duration since the
// last connection was dropped
time::sleep(wait_duration).await;
// If we do have a connection at this point, don't exit
if !tracker_2.lock().await.has_reached_timeout(duration) {
continue;
}
// Otherwise, we now should exit, which we do by reporting
debug!(
"Shutdown time of {}s has been reached!",
duration.as_secs_f32()
);
break;
}
// Otherwise, we just wait the full duration as worst case
// we'll have waited just about the time desired if right
// after waiting starts the last connection is closed
time::sleep(duration).await;
}
});
Self { task, tracker }
}
/// Produces a new copy of the connection tracker associated with the shutdown manager
pub fn tracker(&self) -> Arc<Mutex<ConnTracker>> {
Arc::clone(&self.tracker)
}
}
pub struct ConnTracker {
time: Instant,
cnt: usize,
}
impl ConnTracker {
pub fn new() -> Self {
Self {
time: Instant::now(),
cnt: 0,
}
}
pub fn increment(&mut self) {
self.time = Instant::now();
self.cnt += 1;
}
pub fn decrement(&mut self) {
if self.cnt > 0 {
self.time = Instant::now();
self.cnt -= 1;
}
}
fn time_and_cnt(&self) -> (Instant, usize) {
(self.time, self.cnt)
}
fn has_reached_timeout(&self, duration: Duration) -> bool {
self.cnt == 0 && self.time.elapsed() >= duration
}
}
#[cfg(test)]
mod tsets {
use super::*;
use std::thread;
#[tokio::test]
async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() {
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
task.tracker().lock().await.increment();
assert!(
futures::poll!(&mut task).is_pending(),
"Shutdown task unexpectedly completed"
);
time::sleep(Duration::from_millis(15)).await;
assert!(
futures::poll!(task).is_pending(),
"Shutdown task unexpectedly completed"
);
}
#[tokio::test]
async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() {
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
assert!(
futures::poll!(&mut task).is_pending(),
"Shutdown task unexpectedly completed"
);
time::sleep(Duration::from_millis(15)).await;
assert!(
futures::poll!(task).is_ready(),
"Shutdown task unexpectedly pending"
);
}
#[tokio::test]
async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed(
) {
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
task.tracker().lock().await.increment();
assert!(
futures::poll!(&mut task).is_pending(),
"Shutdown task unexpectedly completed"
);
time::sleep(Duration::from_millis(15)).await;
assert!(
futures::poll!(&mut task).is_pending(),
"Shutdown task unexpectedly completed"
);
task.tracker().lock().await.decrement();
time::sleep(Duration::from_millis(15)).await;
assert!(
futures::poll!(task).is_ready(),
"Shutdown task unexpectedly pending"
);
}
#[tokio::test]
async fn shutdown_task_should_not_resolve_before_minimum_duration() {
let mut task = ShutdownTask::initialize(Duration::from_millis(10));
assert!(
futures::poll!(&mut task).is_pending(),
"Shutdown task unexpectedly completed"
);
time::sleep(Duration::from_millis(5)).await;
assert!(
futures::poll!(task).is_pending(),
"Shutdown task unexpectedly completed"
);
}
#[test]
fn conn_tracker_should_update_time_when_incremented() {
let mut tracker = ConnTracker::new();
let (old_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 0);
// Wait to ensure that the new time will be different
thread::sleep(Duration::from_millis(1));
tracker.increment();
let (new_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 1);
assert!(new_time > old_time);
}
#[test]
fn conn_tracker_should_update_time_when_decremented() {
let mut tracker = ConnTracker::new();
tracker.increment();
let (old_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 1);
// Wait to ensure that the new time will be different
thread::sleep(Duration::from_millis(1));
tracker.decrement();
let (new_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 0);
assert!(new_time > old_time);
}
#[test]
fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() {
let mut tracker = ConnTracker::new();
let (old_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 0);
// Wait to ensure that the new time would be different if updated
thread::sleep(Duration::from_millis(1));
tracker.decrement();
let (new_time, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 0);
assert!(new_time == old_time);
}
#[test]
fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() {
let tracker = ConnTracker::new();
let (_, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 0);
// Wait to ensure that the new time would be different if updated
thread::sleep(Duration::from_millis(1));
assert!(tracker.has_reached_timeout(Duration::from_millis(1)));
}
#[test]
fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() {
let mut tracker = ConnTracker::new();
tracker.increment();
let (_, cnt) = tracker.time_and_cnt();
assert_eq!(cnt, 1);
// Wait to ensure that the new time would be different if updated
thread::sleep(Duration::from_millis(1));
assert!(!tracker.has_reached_timeout(Duration::from_millis(1)));
}
}

@ -1,7 +1,7 @@
mod cli;
mod core;
pub use self::core::{data, net};
pub use self::core::{client::*, data, net, server::*};
use log::error;
/// Main entrypoint into the program

Loading…
Cancel
Save