Big refactor that is not finished

pull/38/head
Chip Senkbeil 3 years ago committed by Chip Senkbeil
parent 2b23cd379c
commit 1ca3cd7859
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -0,0 +1,51 @@
use std::ops::{Deref, DerefMut};
/// Wraps a string to provide some friendly read and write methods
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StringBuf(String);
impl StringBuf {
pub fn new() -> Self {
Self(String::new())
}
/// Consumes data within the buffer that represent full lines (end with a newline) and returns
/// the string containing those lines.
///
/// The remaining buffer contains are returned as the second part of the tuple
pub fn into_full_lines(mut self) -> (Option<String>, StringBuf) {
match self.rfind('\n') {
Some(idx) => {
let remaining = self.0.split_off(idx + 1);
(Some(self.0), Self(remaining))
}
None => (None, self),
}
}
}
impl From<String> for StringBuf {
fn from(x: String) -> Self {
Self(x)
}
}
impl From<StringBuf> for String {
fn from(x: StringBuf) -> Self {
x.0
}
}
impl Deref for StringBuf {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for StringBuf {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

@ -0,0 +1,75 @@
use crate::core::net::TransportError;
/// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ExitCode {
/// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI
Usage = 64,
/// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data
/// is bad
DataErr = 65,
/// EX_NOINPUT (66) - being used when not getting expected data from launch
NoInput = 66,
/// EX_NOHOST (68) - being used when failed to resolve a host
NoHost = 68,
/// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem
Unavailable = 69,
/// EX_SOFTWARE (70) - being used for internal errors that can occur like joining a task
Software = 70,
/// EX_OSERR (71) - being used when fork failed
OsErr = 71,
/// EX_IOERR (74) - being used as catchall for IO errors
IoError = 74,
/// EX_TEMPFAIL (75) - being used when we get a timeout
TempFail = 75,
/// EX_PROTOCOL (76) - being used as catchall for transport errors
Protocol = 76,
}
/// Represents an error that can be converted into an exit code
pub trait ExitCodeError: std::error::Error {
fn to_exit_code(&self) -> ExitCode;
fn to_i32(&self) -> i32 {
self.to_exit_code() as i32
}
}
impl ExitCodeError for std::io::Error {
fn to_exit_code(&self) -> ExitCode {
use std::io::ErrorKind;
match self.kind() {
ErrorKind::ConnectionAborted
| ErrorKind::ConnectionRefused
| ErrorKind::ConnectionReset
| ErrorKind::NotConnected => ExitCode::Unavailable,
ErrorKind::InvalidData => ExitCode::DataErr,
ErrorKind::TimedOut => ExitCode::TempFail,
_ => ExitCode::IoError,
}
}
}
impl ExitCodeError for TransportError {
fn to_exit_code(&self) -> ExitCode {
match self {
TransportError::IoError(x) => x.to_exit_code(),
_ => ExitCode::Protocol,
}
}
}
impl<T: ExitCodeError + 'static> From<T> for Box<dyn ExitCodeError> {
fn from(x: T) -> Self {
Box::new(x)
}
}

@ -1,2 +1,10 @@
pub mod opt;
mod buf;
mod exit;
mod opt;
mod output;
mod session;
mod subcommand;
pub use exit::{ExitCode, ExitCodeError};
pub use opt::*;
pub use output::ResponseOut;

@ -1,19 +1,19 @@
use crate::{
cli::subcommand,
cli::{subcommand, ExitCodeError},
core::{
constants::{
SERVER_CONN_MSG_CAPACITY_STR, SESSION_FILE_PATH_STR, SESSION_SOCKET_PATH_STR,
TIMEOUT_STR,
},
data::RequestData,
server::PortRange,
},
ExitCodeError,
};
use derive_more::{Display, Error, From, IsVariant};
use lazy_static::lazy_static;
use std::{
env,
net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr},
path::PathBuf,
str::FromStr,
time::Duration,
@ -438,67 +438,6 @@ impl LaunchSubcommand {
}
}
/// Represents some range of ports
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
fmt = "{}{}",
start,
"end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()"
)]
pub struct PortRange {
pub start: u16,
pub end: Option<u16>,
}
impl PortRange {
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
let mut socket_addrs = Vec::new();
let addr = addr.into();
for port in self.start..=self.end.unwrap_or(self.start) {
socket_addrs.push(SocketAddr::from((addr, port)));
}
socket_addrs
}
}
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum PortRangeParseError {
InvalidPort,
MissingPort,
}
impl FromStr for PortRange {
type Err = PortRangeParseError;
/// Parses PORT into single range or PORT1:PORTN into full range
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut tokens = s.trim().split(':');
let start = tokens
.next()
.ok_or(PortRangeParseError::MissingPort)?
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?;
let end = if let Some(token) = tokens.next() {
Some(
token
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?,
)
} else {
None
};
if tokens.next().is_some() {
return Err(PortRangeParseError::InvalidPort);
}
Ok(Self { start, end })
}
}
/// Represents subcommand to operate in listen mode for incoming requests
#[derive(Debug, StructOpt)]
pub struct ListenSubcommand {

@ -0,0 +1,178 @@
use crate::{
cli::Format,
core::data::{Error, Response, ResponseData},
};
use log::*;
use std::io;
/// Represents the output content and destination
pub enum ResponseOut {
Stdout(String),
StdoutLine(String),
Stderr(String),
StderrLine(String),
None,
}
impl ResponseOut {
/// Create a new output message for the given response based on the specified format
pub fn new(format: Format, res: Response) -> io::Result<ResponseOut> {
let payload_cnt = res.payload.len();
Ok(match format {
Format::Json => ResponseOut::StdoutLine(format!(
"{}",
serde_json::to_string(&res)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?
)),
// NOTE: For shell, we assume a singular entry in the response's payload
Format::Shell if payload_cnt != 1 => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Got {} entries in payload data, but shell expects exactly 1",
payload_cnt
),
))
}
Format::Shell => format_shell(res.payload.into_iter().next().unwrap()),
})
}
/// Consumes the output message, printing it based on its configuration
pub fn print(self) {
match self {
Self::Stdout(x) => {
// NOTE: Because we are not including a newline in the output,
// it is not guaranteed to be written out. In the case of
// LSP protocol, the JSON content is not followed by a
// newline and was not picked up when the response was
// sent back to the client; so, we need to manually flush
use std::io::Write;
print!("{}", x);
if let Err(x) = std::io::stdout().lock().flush() {
error!("Failed to flush stdout: {}", x);
}
}
Self::StdoutLine(x) => println!("{}", x),
Self::Stderr(x) => {
use std::io::Write;
eprint!("{}", x);
if let Err(x) = std::io::stderr().lock().flush() {
error!("Failed to flush stderr: {}", x);
}
}
Self::StderrLine(x) => eprintln!("{}", x),
Self::None => {}
}
}
}
fn format_shell(data: ResponseData) -> ResponseOut {
match data {
ResponseData::Ok => ResponseOut::None,
ResponseData::Error(Error { kind, description }) => {
ResponseOut::StderrLine(format!("Failed ({}): '{}'.", kind, description))
}
ResponseData::Blob { data } => {
ResponseOut::StdoutLine(String::from_utf8_lossy(&data).to_string())
}
ResponseData::Text { data } => ResponseOut::StdoutLine(data),
ResponseData::DirEntries { entries, .. } => ResponseOut::StdoutLine(format!(
"{}",
entries
.into_iter()
.map(|entry| {
format!(
"{}{}",
entry.path.as_os_str().to_string_lossy(),
if entry.file_type.is_dir() {
// NOTE: This can be different from the server if
// the server OS is unix and the client is
// not or vice versa; for now, this doesn't
// matter as we only support unix-based
// operating systems, but something to keep
// in mind
std::path::MAIN_SEPARATOR.to_string()
} else {
String::new()
},
)
})
.collect::<Vec<String>>()
.join("\n"),
)),
ResponseData::Exists(exists) => {
if exists {
ResponseOut::StdoutLine("Does exist.".to_string())
} else {
ResponseOut::StdoutLine("Does not exist.".to_string())
}
}
ResponseData::Metadata {
canonicalized_path,
file_type,
len,
readonly,
accessed,
created,
modified,
} => ResponseOut::StdoutLine(format!(
concat!(
"{}",
"Type: {}\n",
"Len: {}\n",
"Readonly: {}\n",
"Created: {}\n",
"Last Accessed: {}\n",
"Last Modified: {}",
),
canonicalized_path
.map(|p| format!("Canonicalized Path: {:?}\n", p))
.unwrap_or_default(),
file_type.as_ref(),
len,
readonly,
created.unwrap_or_default(),
accessed.unwrap_or_default(),
modified.unwrap_or_default(),
)),
ResponseData::ProcEntries { entries } => ResponseOut::StdoutLine(format!(
"{}",
entries
.into_iter()
.map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" ")))
.collect::<Vec<String>>()
.join("\n"),
)),
ResponseData::ProcStart { .. } => ResponseOut::None,
ResponseData::ProcStdout { data, .. } => ResponseOut::Stdout(data),
ResponseData::ProcStderr { data, .. } => ResponseOut::Stderr(data),
ResponseData::ProcDone { id, success, code } => {
if success {
ResponseOut::None
} else if let Some(code) = code {
ResponseOut::StderrLine(format!("Proc {} failed with code {}", id, code))
} else {
ResponseOut::StderrLine(format!("Proc {} failed", id))
}
}
ResponseData::SystemInfo {
family,
os,
arch,
current_dir,
main_separator,
} => ResponseOut::StdoutLine(format!(
concat!(
"Family: {:?}\n",
"Operating System: {:?}\n",
"Arch: {:?}\n",
"Cwd: {:?}\n",
"Path Sep: {:?}",
),
family, os, arch, current_dir, main_separator,
)),
}
}

@ -0,0 +1,155 @@
use crate::{
cli::{buf::StringBuf, Format, ResponseOut},
core::{
client::Session,
constants::MAX_PIPE_CHUNK_SIZE,
data::{Request, Response},
net::DataStream,
},
};
use log::*;
use std::{
io::{self, BufReader, Read},
sync::Arc,
thread,
};
use tokio::sync::{mpsc, watch};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// Represents a wrapper around a session that provides CLI functionality such as reading from
/// stdin and piping results back out to stdout
pub struct CliSession<T>
where
T: DataStream,
{
inner: Session<T>,
}
impl<T> CliSession<T>
where
T: DataStream,
{
pub fn new(inner: Session<T>) -> Self {
Self { inner }
}
}
// TODO TODO TODO:
//
// 1. Change watch to broadcast if going to use in both loops, otherwise just make
// it an mpsc otherwise
// 2. Need to provide outgoing requests function with logic from inner.rs to create a request
// based on the format (json or shell), where json uses serde_json::from_str and shell
// uses Request::new(tenant.as_str(), vec![RequestData::from_iter_safe(...)])
// 3. Need to add a wait method to block on the running tasks
// 4. Need to add an abort method to abort the tasks
// 5. Is there any way to deal with the blocking thread for stdin to kill it? This isn't a big
// deal as the shutdown would only be happening on client termination anyway, but still...
/// Helper function that loops, processing incoming responses not tied to a request to be sent out
/// over stdout/stderr
async fn process_incoming_responses(
mut stream: BroadcastStream<Response>,
format: Format,
mut exit: watch::Receiver<bool>,
) -> io::Result<()> {
loop {
tokio::select! {
res = stream.next() => {
match res {
Some(Ok(res)) => ResponseOut::new(format, res)?.print(),
Some(Err(x)) => return Err(io::Error::new(io::ErrorKind::BrokenPipe, x)),
None => return Ok(()),
}
}
_ = exit.changed() => {
return Ok(());
}
}
}
}
/// Helper function that loops, processing outgoing requests created from stdin, and printing out
/// responses
async fn process_outgoing_requests<T, F>(
mut session: Session<T>,
mut stdin_rx: mpsc::Receiver<String>,
format: Format,
map_line: F,
) where
T: DataStream,
F: Fn(&str) -> io::Result<Request>,
{
let mut buf = StringBuf::new();
while let Some(data) = stdin_rx.recv().await {
// Update our buffer with the new data and split it into concrete lines and remainder
buf.push_str(&data);
let (lines, new_buf) = buf.into_full_lines();
buf = new_buf;
// For each complete line, parse into a request
if let Some(lines) = lines {
for line in lines.lines() {
trace!("Processing line: {:?}", line);
if line.trim().is_empty() {
continue;
}
match map_line(line) {
Ok(req) => match session.send(req).await {
Ok(res) => match ResponseOut::new(format, res) {
Ok(out) => out.print(),
Err(x) => error!("Failed to format response: {}", x),
},
Err(x) => {
error!("Failed to send request: {}", x)
}
},
Err(x) => {
error!("Failed to parse line: {}", x);
}
}
}
}
}
}
/// Creates a new thread that performs stdin reads in a blocking fashion, returning
/// a handle to the thread and a receiver that will be sent input as it becomes available
fn spawn_stdin_reader() -> (thread::JoinHandle<()>, mpsc::Receiver<String>) {
let (tx, rx) = mpsc::channel(1);
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
// pass the results to a separate async handler to forward to the remote process
let handle = thread::spawn(move || {
let mut stdin = BufReader::new(io::stdin());
// Maximum chunk that we expect to read at any one time
let mut buf = [0; MAX_PIPE_CHUNK_SIZE];
loop {
match stdin.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(n) => {
match String::from_utf8(buf[..n].to_vec()) {
Ok(text) => {
if let Err(x) = tx.blocking_send(text) {
error!(
"Failed to pass along stdin to be sent to remote process: {}",
x
);
}
}
Err(x) => {
error!("Input over stdin is invalid: {}", x);
}
}
thread::yield_now();
}
}
}
});
(handle, rx)
}

@ -180,211 +180,3 @@ where
Ok(())
}
fn spawn_stdin_reader() -> mpsc::Receiver<String> {
let (tx, rx) = mpsc::channel(1);
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
// pass the results to a separate async handler to forward to the remote process
std::thread::spawn(move || {
use std::io::{self, BufReader, Read};
let mut stdin = BufReader::new(io::stdin());
// Maximum chunk that we expect to read at any one time
let mut buf = [0; MAX_PIPE_CHUNK_SIZE];
loop {
match stdin.read(&mut buf) {
Ok(0) | Err(_) => break,
Ok(n) => {
match String::from_utf8(buf[..n].to_vec()) {
Ok(text) => {
if let Err(x) = tx.blocking_send(text) {
error!(
"Failed to pass along stdin to be sent to remote process: {}",
x
);
}
}
Err(x) => {
error!("Input over stdin is invalid: {}", x);
}
}
std::thread::yield_now();
}
}
}
});
rx
}
/// Represents the output content and destination
pub enum ResponseOut {
Stdout(String),
StdoutLine(String),
Stderr(String),
StderrLine(String),
None,
}
impl ResponseOut {
pub fn print(self) {
match self {
Self::Stdout(x) => {
// NOTE: Because we are not including a newline in the output,
// it is not guaranteed to be written out. In the case of
// LSP protocol, the JSON content is not followed by a
// newline and was not picked up when the response was
// sent back to the client; so, we need to manually flush
use std::io::Write;
print!("{}", x);
if let Err(x) = std::io::stdout().lock().flush() {
error!("Failed to flush stdout: {}", x);
}
}
Self::StdoutLine(x) => println!("{}", x),
Self::Stderr(x) => {
use std::io::Write;
eprint!("{}", x);
if let Err(x) = std::io::stderr().lock().flush() {
error!("Failed to flush stderr: {}", x);
}
}
Self::StderrLine(x) => eprintln!("{}", x),
Self::None => {}
}
}
}
pub fn format_response(format: Format, res: Response) -> io::Result<ResponseOut> {
let payload_cnt = res.payload.len();
Ok(match format {
Format::Json => ResponseOut::StdoutLine(format!(
"{}",
serde_json::to_string(&res)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?
)),
// NOTE: For shell, we assume a singular entry in the response's payload
Format::Shell if payload_cnt != 1 => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Got {} entries in payload data, but shell expects exactly 1",
payload_cnt
),
))
}
Format::Shell => format_shell(res.payload.into_iter().next().unwrap()),
})
}
fn format_shell(data: ResponseData) -> ResponseOut {
match data {
ResponseData::Ok => ResponseOut::None,
ResponseData::Error(Error { kind, description }) => {
ResponseOut::StderrLine(format!("Failed ({}): '{}'.", kind, description))
}
ResponseData::Blob { data } => {
ResponseOut::StdoutLine(String::from_utf8_lossy(&data).to_string())
}
ResponseData::Text { data } => ResponseOut::StdoutLine(data),
ResponseData::DirEntries { entries, .. } => ResponseOut::StdoutLine(format!(
"{}",
entries
.into_iter()
.map(|entry| {
format!(
"{}{}",
entry.path.as_os_str().to_string_lossy(),
if entry.file_type.is_dir() {
// NOTE: This can be different from the server if
// the server OS is unix and the client is
// not or vice versa; for now, this doesn't
// matter as we only support unix-based
// operating systems, but something to keep
// in mind
std::path::MAIN_SEPARATOR.to_string()
} else {
String::new()
},
)
})
.collect::<Vec<String>>()
.join("\n"),
)),
ResponseData::Exists(exists) => {
if exists {
ResponseOut::StdoutLine("Does exist.".to_string())
} else {
ResponseOut::StdoutLine("Does not exist.".to_string())
}
}
ResponseData::Metadata {
canonicalized_path,
file_type,
len,
readonly,
accessed,
created,
modified,
} => ResponseOut::StdoutLine(format!(
concat!(
"{}",
"Type: {}\n",
"Len: {}\n",
"Readonly: {}\n",
"Created: {}\n",
"Last Accessed: {}\n",
"Last Modified: {}",
),
canonicalized_path
.map(|p| format!("Canonicalized Path: {:?}\n", p))
.unwrap_or_default(),
file_type.as_ref(),
len,
readonly,
created.unwrap_or_default(),
accessed.unwrap_or_default(),
modified.unwrap_or_default(),
)),
ResponseData::ProcEntries { entries } => ResponseOut::StdoutLine(format!(
"{}",
entries
.into_iter()
.map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" ")))
.collect::<Vec<String>>()
.join("\n"),
)),
ResponseData::ProcStart { .. } => ResponseOut::None,
ResponseData::ProcStdout { data, .. } => ResponseOut::Stdout(data),
ResponseData::ProcStderr { data, .. } => ResponseOut::Stderr(data),
ResponseData::ProcDone { id, success, code } => {
if success {
ResponseOut::None
} else if let Some(code) = code {
ResponseOut::StderrLine(format!("Proc {} failed with code {}", id, code))
} else {
ResponseOut::StderrLine(format!("Proc {} failed", id))
}
}
ResponseData::SystemInfo {
family,
os,
arch,
current_dir,
main_separator,
} => ResponseOut::StdoutLine(format!(
concat!(
"Family: {:?}\n",
"Operating System: {:?}\n",
"Arch: {:?}\n",
"Cwd: {:?}\n",
"Path Sep: {:?}",
),
family, os, arch, current_dir, main_separator,
)),
}
}

@ -1,13 +1,13 @@
use crate::{
cli::opt::{ActionSubcommand, CommonOpt, Format, SessionInput},
cli::{
opt::{ActionSubcommand, CommonOpt, Format, SessionInput},
ExitCode, ExitCodeError,
},
core::{
client::{LspData, Session, SessionInfo, SessionInfoFile},
data::{Request, RequestData, ResponseData},
lsp::LspData,
net::{Client, DataStream, TransportError},
session::{Session, SessionFile},
utils,
net::{DataStream, TransportError},
},
ExitCode, ExitCodeError,
};
use derive_more::{Display, Error, From};
use log::*;
@ -47,7 +47,7 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
SessionInput::Environment => {
start(
cmd,
Client::tcp_connect_timeout(Session::from_environment()?, timeout).await?,
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
timeout,
None,
)
@ -57,8 +57,11 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let path = cmd.session_data.session_file.clone();
start(
cmd,
Client::tcp_connect_timeout(SessionFile::load_from(path).await?.into(), timeout)
.await?,
Session::tcp_connect_timeout(
SessionInfoFile::load_from(path).await?.into(),
timeout,
)
.await?,
timeout,
None,
)
@ -67,7 +70,7 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
SessionInput::Pipe => {
start(
cmd,
Client::tcp_connect_timeout(Session::from_stdin()?, timeout).await?,
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
timeout,
None,
)
@ -76,10 +79,10 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
SessionInput::Lsp => {
let mut data =
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
let session = data.take_session().map_err(io::Error::from)?;
let info = data.take_session_info().map_err(io::Error::from)?;
start(
cmd,
Client::tcp_connect_timeout(session, timeout).await?,
Session::tcp_connect_timeout(info, timeout).await?,
timeout,
Some(data),
)
@ -90,7 +93,7 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let path = cmd.session_data.session_socket.clone();
start(
cmd,
Client::unix_connect_timeout(path, None, timeout).await?,
Session::unix_connect_timeout(path, None, timeout).await?,
timeout,
None,
)
@ -101,17 +104,59 @@ async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
async fn start<T>(
cmd: ActionSubcommand,
mut client: Client<T>,
mut session: Session<T>,
timeout: Duration,
lsp_data: Option<LspData>,
) -> Result<(), Error>
where
T: DataStream + 'static,
{
if !cmd.interactive && cmd.operation.is_none() {
return Err(Error::MissingOperation);
// TODO: Because lsp is being handled in a separate action, we should fail if we get
// a session type of lsp for a regular action
match (cmd.interactive, cmd.operation) {
// ProcRun request is specially handled and we ignore interactive as
// the stdin will be used for sending ProcStdin to remote process
(_, Some(RequestData::ProcRun { cmd, args })) => {}
// All other requests without interactive are oneoffs
(false, Some(req)) => {
let res = session.send_timeout(req, timeout).await?;
}
// Interactive mode will send an optional first request and then continue
// to read stdin to send more
(true, maybe_req) => {}
// Not interactive and no operation given
(false, None) => Err(Error::MissingOperation),
}
// 1. Determine what type of engagement we're doing
// a. Oneoff connection, request, response
// b. ProcRun where we take over stdin, stdout, stderr to provide a remote
// process experience
// c. Lsp where we do the ProcRun stuff, but translate stdin before sending and
// stdout before outputting
// d. Interactive program
//
// 2. If we have a queued up operation, we need to perform it
// a. For oneoff, this is the request of the oneoff
// b. For Procrun, this is the request that starts the process
// c. For Lsp, this is the request that starts the process
// d. For interactive, this is an optional first request
//
// 3. If we are using LSP session mode, then we want to send the
// ProcStdin request after our optional queued up operation
// a. For oneoff, this doesn't make sense and we should fail
// b. For ProcRun, we do this after the ProcStart
// c. For Lsp, we do this after the ProcStart
// d. For interactive, this doesn't make sense as we only support
// JSON and shell command input, not LSP input, so this would
// fail and we should fail early
//
// ** LSP would be its own action, which means we want to abstract the logic that feeds
// into this start method such that it can also be used with lsp action
// Make up a tenant name
let tenant = utils::new_tenant();
@ -127,7 +172,7 @@ where
is_proc_req = req.payload.iter().any(|x| x.is_proc_run());
debug!("Client sending request: {:?}", req);
let res = client.send_timeout(req, timeout).await?;
let res = session.send_timeout(req, timeout).await?;
// Store the spawned process id for using in sending stdin (if we spawned a proc)
// NOTE: We can assume that there is a single payload entry in response to our single
@ -144,7 +189,7 @@ where
// TODO: Do we need to do this somewhere else to apply to all possible ways an LSP
// could be started?
if let Some(data) = lsp_data {
client
session
.fire_timeout(
Request::new(
tenant.as_str(),

@ -1,5 +1,8 @@
use crate::{
cli::opt::{CommonOpt, Format, LaunchSubcommand, SessionOutput},
cli::{
opt::{CommonOpt, Format, LaunchSubcommand, SessionOutput},
ExitCode, ExitCodeError,
},
core::{
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
@ -7,7 +10,6 @@ use crate::{
session::{Session, SessionFile},
utils,
},
ExitCode, ExitCodeError,
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};

@ -0,0 +1,94 @@
use crate::{
cli::{
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
ExitCode, ExitCodeError,
},
core::server::DistantServer,
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use tokio::{io, task::JoinError};
#[derive(Debug, Display, Error, From)]
pub enum Error {
ConvertToIpAddrError(ConvertToIpAddrError),
ForkError,
IoError(io::Error),
JoinError(JoinError),
}
impl ExitCodeError for Error {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::ConvertToIpAddrError(_) => ExitCode::NoHost,
Self::ForkError => ExitCode::OsErr,
Self::IoError(x) => x.to_exit_code(),
Self::JoinError(_) => ExitCode::Software,
}
}
}
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
if cmd.daemon {
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
match daemon(false, true) {
Ok(Fork::Child) => {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, true).await })?;
}
Ok(Fork::Parent(pid)) => {
info!("[distant detached, pid = {}]", pid);
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
Err(_) => return Err(Error::ForkError),
}
} else {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, false).await })?;
}
Ok(())
}
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
let socket_addrs = cmd.port.make_socket_addrs(addr);
let shutdown_after = cmd.to_shutdown_after_duration();
// If specified, change the current working directory of this program
if let Some(path) = cmd.current_dir.as_ref() {
debug!("Setting current directory to {:?}", path);
std::env::set_current_dir(path)?;
}
// Bind & start our server
let server = DistantServer::bind(
addr,
cmd.port,
cmd.to_shutdown_after_duration(),
cmd.max_msg_capacity as usize,
)
.await?;
// Print information about port, key, etc.
println!(
"DISTANT DATA -- {} {}",
server.port(),
server.to_unprotected_hex_auth_key()
);
// For the child, we want to fully disconnect it from pipes, which we do now
if is_forked {
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
// Let our server run to completion
server.wait().await?;
Ok(())
}

@ -1,207 +0,0 @@
use crate::{
cli::opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
core::{
data::{Request, Response},
net::{Transport, TransportReadHalf, TransportWriteHalf},
session::Session,
state::ServerState,
utils,
},
ExitCode, ExitCodeError,
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use orion::aead::SecretKey;
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{tcp, TcpListener},
runtime::Handle,
sync::{mpsc, Mutex},
};
mod handler;
#[derive(Debug, Display, Error, From)]
pub enum Error {
ConvertToIpAddrError(ConvertToIpAddrError),
ForkError,
IoError(io::Error),
}
impl ExitCodeError for Error {
fn to_exit_code(&self) -> ExitCode {
match self {
Self::ConvertToIpAddrError(_) => ExitCode::NoHost,
Self::ForkError => ExitCode::OsErr,
Self::IoError(x) => x.to_exit_code(),
}
}
}
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
if cmd.daemon {
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
match daemon(false, true) {
Ok(Fork::Child) => {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, true).await })?;
}
Ok(Fork::Parent(pid)) => {
info!("[distant detached, pid = {}]", pid);
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
Err(_) => return Err(Error::ForkError),
}
} else {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(async { run_async(cmd, opt, false).await })?;
}
Ok(())
}
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
let socket_addrs = cmd.port.make_socket_addrs(addr);
let shutdown_after = cmd.to_shutdown_after_duration();
// If specified, change the current working directory of this program
if let Some(path) = cmd.current_dir.as_ref() {
debug!("Setting current directory to {:?}", path);
std::env::set_current_dir(path)?;
}
debug!("Binding to {} in range {}", addr, cmd.port);
let listener = TcpListener::bind(socket_addrs.as_slice()).await?;
let port = listener.local_addr()?.port();
debug!("Bound to port: {}", port);
// Print information about port, key, etc.
let key = {
let session = Session {
host: "--".to_string(),
port,
auth_key: SecretKey::default(),
};
println!("{}", session.to_unprotected_string());
Arc::new(session.into_auth_key())
};
// For the child, we want to fully disconnect it from pipes, which we do now
if is_forked {
if let Err(_) = fork::close_fd() {
return Err(Error::ForkError);
}
}
// Build our state for the server
let state: Arc<Mutex<ServerState<SocketAddr>>> = Arc::new(Mutex::new(ServerState::default()));
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
// Wait for a client connection, then spawn a new task to handle
// receiving data from the client
loop {
tokio::select! {
result = listener.accept() => {match result {
Ok((client, addr)) => {
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {}> Failed handshake: {}", addr, x);
continue;
}
};
// Split the transport into read and write halves so we can handle input
// and output concurrently
let (t_read, t_write) = transport.into_split();
let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
let ct_2 = Arc::clone(&ct);
// Spawn a new task that loops to handle requests from the client
tokio::spawn({
let f = request_loop(addr, Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
ct_2.lock().await.increment();
f.await;
state.lock().await.cleanup_client(addr).await;
ct_2.lock().await.decrement();
}
});
// Spawn a new task that loops to handle responses to the client
tokio::spawn(async move { response_loop(addr, t_write, rx).await });
}
Err(x) => {
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
}
}
Ok(())
}
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
async fn request_loop(
addr: SocketAddr,
state: Arc<Mutex<ServerState<SocketAddr>>>,
mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
tx: mpsc::Sender<Response>,
) {
loop {
match transport.receive::<Request>().await {
Ok(Some(req)) => {
debug!(
"<Client @ {}> Received request of type{} {}",
addr,
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await {
error!("<Client @ {}> {}", addr, x);
break;
}
}
Ok(None) => {
info!("<Client @ {}> Closed connection", addr);
break;
}
Err(x) => {
error!("<Client @ {}> {}", addr, x);
break;
}
}
}
}
/// Repeatedly sends responses out over the wire
async fn response_loop(
addr: SocketAddr,
mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>,
mut rx: mpsc::Receiver<Response>,
) {
while let Some(res) = rx.recv().await {
if let Err(x) = transport.send(res).await {
error!("<Client @ {}> {}", addr, x);
break;
}
}
}

@ -1,4 +1,4 @@
use crate::core::session::{Session, SessionParseError};
use crate::core::client::{SessionInfo, SessionInfoParseError};
use derive_more::{Display, Error, From};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
@ -11,24 +11,24 @@ use std::{
};
#[derive(Copy, Clone, Debug, PartialEq, Eq, Display, Error, From)]
pub enum LspSessionError {
pub enum LspSessionInfoError {
/// Encountered when attempting to create a session from a request that is not initialize
NotInitializeRequest,
/// Encountered if missing session parameters within an initialize request
MissingSessionParams,
MissingSessionInfoParams,
/// Encountered if session parameters are not expected types
InvalidSessionParams,
InvalidSessionInfoParams,
/// Encountered when failing to parse session
SessionParseError(SessionParseError),
SessionInfoParseError(SessionInfoParseError),
}
impl From<LspSessionError> for io::Error {
fn from(x: LspSessionError) -> Self {
impl From<LspSessionInfoError> for io::Error {
fn from(x: LspSessionInfoError) -> Self {
match x {
LspSessionError::SessionParseError(x) => x.into(),
LspSessionInfoError::SessionInfoParseError(x) => x.into(),
x => io::Error::new(io::ErrorKind::InvalidData, x),
}
}
@ -38,19 +38,19 @@ impl From<LspSessionError> for io::Error {
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LspData {
/// Header-portion of some data related to LSP
header: LspDataHeader,
header: LspHeader,
/// Content-portion of some data related to LSP
content: LspDataContent,
content: LspContent,
}
#[derive(Debug, Display, Error, From)]
pub enum LspDataParseError {
/// When the received content is malformed
BadContent(LspDataContentParseError),
BadContent(LspContentParseError),
/// When the received header is malformed
BadHeader(LspDataHeaderParseError),
BadHeader(LspHeaderParseError),
/// When a header line is not terminated in \r\n
BadHeaderTermination,
@ -83,20 +83,29 @@ impl From<LspDataParseError> for io::Error {
impl LspData {
/// Returns a reference to the header part
pub fn header(&self) -> &LspDataHeader {
pub fn header(&self) -> &LspHeader {
&self.header
}
/// Returns a mutable reference to the header part
pub fn mut_header(&mut self) -> &mut LspHeader {
&mut self.header
}
/// Returns a reference to the content part
pub fn content(&self) -> &LspDataContent {
pub fn content(&self) -> &LspContent {
&self.content
}
/// Creates a session by inspecting the content for session parameters, removing the session
/// parameters from the content. Will also adjust the content length header to match the
/// new size of the content.
pub fn take_session(&mut self) -> Result<Session, LspSessionError> {
match self.content.take_session() {
/// Returns a mutable reference to the content part
pub fn mut_content(&mut self) -> &mut LspContent {
&mut self.content
}
/// Creates a session's info by inspecting the content for session parameters, removing the
/// session parameters from the content. Will also adjust the content length header to match
/// the new size of the content.
pub fn take_session_info(&mut self) -> Result<SessionInfo, LspSessionInfoError> {
match self.content.take_session_info() {
Ok(session) => {
self.header.content_length = self.content.to_string().len();
Ok(session)
@ -107,7 +116,7 @@ impl LspData {
/// Attempts to read incoming lsp data from a buffered reader.
///
/// Note that this is **blocking** while it waits on the header information!
/// Note that this is **blocking** while it waits on the header information (or EOF)!
///
/// ```text
/// Content-Length: ...\r\n
@ -147,7 +156,7 @@ impl LspData {
}
// Parse the header content so we know how much more to read
let header = buf.parse::<LspDataHeader>()?;
let header = buf.parse::<LspHeader>()?;
// Read remaining content
let content = {
@ -159,7 +168,7 @@ impl LspData {
LspDataParseError::IoError(x)
}
})?;
String::from_utf8(buf)?.parse::<LspDataContent>()?
String::from_utf8(buf)?.parse::<LspContent>()?
};
Ok(Self { header, content })
@ -206,7 +215,7 @@ impl FromStr for LspData {
/// Represents the header for LSP data
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LspDataHeader {
pub struct LspHeader {
/// Length of content part in bytes
pub content_length: usize,
@ -215,7 +224,7 @@ pub struct LspDataHeader {
pub content_type: Option<String>,
}
impl fmt::Display for LspDataHeader {
impl fmt::Display for LspHeader {
/// Outputs header in form
///
/// ```text
@ -235,20 +244,20 @@ impl fmt::Display for LspDataHeader {
}
#[derive(Clone, Debug, PartialEq, Eq, Display, Error, From)]
pub enum LspDataHeaderParseError {
pub enum LspHeaderParseError {
MissingContentLength,
InvalidContentLength(std::num::ParseIntError),
BadHeaderField,
}
impl From<LspDataHeaderParseError> for io::Error {
fn from(x: LspDataHeaderParseError) -> Self {
impl From<LspHeaderParseError> for io::Error {
fn from(x: LspHeaderParseError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, x)
}
}
impl FromStr for LspDataHeader {
type Err = LspDataHeaderParseError;
impl FromStr for LspHeader {
type Err = LspHeaderParseError;
/// Parses headers in the form of
///
@ -267,10 +276,10 @@ impl FromStr for LspDataHeader {
match name {
"Content-Length" => content_length = Some(value.trim().parse()?),
"Content-Type" => content_type = Some(value.trim().to_string()),
_ => return Err(LspDataHeaderParseError::BadHeaderField),
_ => return Err(LspHeaderParseError::BadHeaderField),
}
} else {
return Err(LspDataHeaderParseError::BadHeaderField);
return Err(LspHeaderParseError::BadHeaderField);
}
}
@ -279,36 +288,102 @@ impl FromStr for LspDataHeader {
content_length,
content_type,
}),
None => Err(LspDataHeaderParseError::MissingContentLength),
None => Err(LspHeaderParseError::MissingContentLength),
}
}
}
/// Represents the content for LSP data
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LspDataContent(Map<String, Value>);
pub struct LspContent(Map<String, Value>);
fn for_each_mut_string<F1, F2>(value: &mut Value, check: F1, mutate: F2)
where
F1: Fn(&String) -> bool,
F2: FnMut(&mut String),
{
match value {
Value::Object(obj) => {
// Mutate values
obj.values_mut()
.for_each(|v| for_each_mut_string(v, check, mutate));
// Mutate keys if necessary
for key in obj.keys() {
if check(key) {
if let Some((key, value)) = obj.remove_entry(key) {
mutate(&mut key);
obj.insert(key, value);
}
}
}
}
Value::Array(items) => items
.iter_mut()
.for_each(|v| for_each_mut_string(v, check, mutate)),
Value::String(s) => mutate(s),
_ => {}
}
}
fn swap_prefix(obj: &mut Map<String, Value>, old: &str, new: &str) {
let check = |s: &String| s.starts_with(old);
let mutate = |s: &mut String| {
if let Some(pos) = s.find(old) {
s.replace_range(pos..old.len(), new);
}
};
// Mutate values
obj.values_mut()
.for_each(|v| for_each_mut_string(v, check, mutate));
// Mutate keys if necessary
for key in obj.keys() {
if check(key) {
if let Some((key, value)) = obj.remove_entry(key) {
mutate(&mut key);
obj.insert(key, value);
}
}
}
}
impl LspContent {
/// Converts all URIs with `file://` as the scheme to `distant://` instead
pub fn convert_local_scheme_to_distant(&mut self) {
swap_prefix(&mut self.0, "file:", "distant:");
}
impl LspDataContent {
/// Creates a session by inspecting the content for session parameters, removing the session
/// parameters from the content
pub fn take_session(&mut self) -> Result<Session, LspSessionError> {
/// Converts all URIs with `distant://` as the scheme to `file://` instead
pub fn convert_distant_scheme_to_local(&mut self) {
swap_prefix(&mut self.0, "distant:", "file:");
}
/// Creates a session's info by inspecting the content for session parameters, removing the
/// session parameters from the content
pub fn take_session_info(&mut self) -> Result<SessionInfo, LspSessionInfoError> {
// Verify that we're dealing with an initialize request
match self.0.get("method") {
Some(value) if value == "initialize" => {}
_ => return Err(LspSessionError::NotInitializeRequest),
_ => return Err(LspSessionInfoError::NotInitializeRequest),
}
// Attempt to grab the distant initialization options
match self.strip_session_params() {
Some((Some(host), Some(port), Some(auth_key))) => {
let host = host.as_str().ok_or(LspSessionError::InvalidSessionParams)?;
let port = port.as_u64().ok_or(LspSessionError::InvalidSessionParams)?;
let host = host
.as_str()
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
let port = port
.as_u64()
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
let auth_key = auth_key
.as_str()
.ok_or(LspSessionError::InvalidSessionParams)?;
.ok_or(LspSessionInfoError::InvalidSessionInfoParams)?;
Ok(format!("DISTANT DATA {} {} {}", host, port, auth_key).parse()?)
}
_ => Err(LspSessionError::MissingSessionParams),
_ => Err(LspSessionInfoError::MissingSessionInfoParams),
}
}
@ -343,13 +418,13 @@ impl LspDataContent {
}
}
impl AsRef<Map<String, Value>> for LspDataContent {
impl AsRef<Map<String, Value>> for LspContent {
fn as_ref(&self) -> &Map<String, Value> {
&self.0
}
}
impl Deref for LspDataContent {
impl Deref for LspContent {
type Target = Map<String, Value>;
fn deref(&self) -> &Self::Target {
@ -357,13 +432,13 @@ impl Deref for LspDataContent {
}
}
impl DerefMut for LspDataContent {
impl DerefMut for LspContent {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl fmt::Display for LspDataContent {
impl fmt::Display for LspContent {
/// Outputs content in JSON form
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
@ -375,16 +450,16 @@ impl fmt::Display for LspDataContent {
}
#[derive(Debug, Display, Error, From)]
pub struct LspDataContentParseError(serde_json::Error);
pub struct LspContentParseError(serde_json::Error);
impl From<LspDataContentParseError> for io::Error {
fn from(x: LspDataContentParseError) -> Self {
impl From<LspContentParseError> for io::Error {
fn from(x: LspContentParseError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, x)
}
}
impl FromStr for LspDataContent {
type Err = LspDataContentParseError;
impl FromStr for LspContent {
type Err = LspContentParseError;
/// Parses content in JSON form
fn from_str(s: &str) -> Result<Self, Self::Err> {
@ -409,11 +484,11 @@ mod tests {
#[test]
fn data_display_should_output_header_and_content() {
let data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({"hello": "world"})),
content: LspContent(make_obj!({"hello": "world"})),
};
let output = data.to_string();
@ -451,6 +526,10 @@ mod tests {
#[test]
fn data_from_buf_reader_should_fail_if_reach_eof_before_received_full_data() {
// No line termination
let err = LspData::from_buf_reader(&mut io::Cursor::new("Content-Length: 22")).unwrap_err();
assert!(matches!(err, LspDataParseError::UnexpectedEof), "{:?}", err);
// Header doesn't finish
let err = LspData::from_buf_reader(&mut io::Cursor::new(concat!(
"Content-Length: 22\r\n",
@ -534,13 +613,13 @@ mod tests {
}
#[test]
fn data_take_session_should_succeed_if_valid_session_found_in_params() {
fn data_take_session_info_should_succeed_if_valid_session_found_in_params() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -554,10 +633,10 @@ mod tests {
})),
};
let session = data.take_session().unwrap();
let info = data.take_session_info().unwrap();
assert_eq!(
session,
Session {
info,
SessionInfo {
host: String::from("some.host"),
port: 22,
auth_key: SecretKey::from_slice(&hex::decode(b"abc123").unwrap()).unwrap(),
@ -566,13 +645,13 @@ mod tests {
}
#[test]
fn data_take_session_should_remove_session_parameters_if_successful() {
fn data_take_session_info_should_remove_session_parameters_if_successful() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -586,7 +665,7 @@ mod tests {
})),
};
let _ = data.take_session().unwrap();
let _ = data.take_session_info().unwrap();
assert_eq!(
data.content.as_ref(),
&make_obj!({
@ -599,13 +678,13 @@ mod tests {
}
#[test]
fn data_take_session_should_adjust_content_length_based_on_new_content_byte_length() {
fn data_take_session_info_should_adjust_content_length_based_on_new_content_byte_length() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -619,18 +698,18 @@ mod tests {
})),
};
let _ = data.take_session().unwrap();
let _ = data.take_session_info().unwrap();
assert_eq!(data.header.content_length, data.content.to_string().len());
}
#[test]
fn data_take_session_should_fail_if_path_incomplete_to_session_params() {
fn data_take_session_info_should_fail_if_path_incomplete_to_session_params() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {}
@ -638,22 +717,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::MissingSessionParams),
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_missing_host_param() {
fn data_take_session_info_should_fail_if_missing_host_param() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -666,22 +745,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::MissingSessionParams),
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_host_param_is_invalid() {
fn data_take_session_info_should_fail_if_host_param_is_invalid() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -695,22 +774,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::InvalidSessionParams),
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_missing_port_param() {
fn data_take_session_info_should_fail_if_missing_port_param() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -723,22 +802,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::MissingSessionParams),
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_port_param_is_invalid() {
fn data_take_session_info_should_fail_if_port_param_is_invalid() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -752,22 +831,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::InvalidSessionParams),
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_missing_auth_key_param() {
fn data_take_session_info_should_fail_if_missing_auth_key_param() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -780,22 +859,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::MissingSessionParams),
matches!(err, LspSessionInfoError::MissingSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_auth_key_param_is_invalid() {
fn data_take_session_info_should_fail_if_auth_key_param_is_invalid() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "initialize",
"params": {
"initializationOptions": {
@ -809,22 +888,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::InvalidSessionParams),
matches!(err, LspSessionInfoError::InvalidSessionInfoParams),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_missing_method_field() {
fn data_take_session_info_should_fail_if_missing_method_field() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"params": {
"initializationOptions": {
"distant": {
@ -837,22 +916,22 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::NotInitializeRequest),
matches!(err, LspSessionInfoError::NotInitializeRequest),
"{:?}",
err
);
}
#[test]
fn data_take_session_should_fail_if_method_field_is_not_initialize() {
fn data_take_session_info_should_fail_if_method_field_is_not_initialize() {
let mut data = LspData {
header: LspDataHeader {
header: LspHeader {
content_length: 123456,
content_type: Some(String::from("some content type")),
},
content: LspDataContent(make_obj!({
content: LspContent(make_obj!({
"method": "not initialize",
"params": {
"initializationOptions": {
@ -866,9 +945,9 @@ mod tests {
})),
};
let err = data.take_session().unwrap_err();
let err = data.take_session_info().unwrap_err();
assert!(
matches!(err, LspSessionError::NotInitializeRequest),
matches!(err, LspSessionInfoError::NotInitializeRequest),
"{:?}",
err
);
@ -877,10 +956,10 @@ mod tests {
#[test]
fn header_parse_should_fail_if_missing_content_length() {
let err = "Content-Type: some type\r\n\r\n"
.parse::<LspDataHeader>()
.parse::<LspHeader>()
.unwrap_err();
assert!(
matches!(err, LspDataHeaderParseError::MissingContentLength),
matches!(err, LspHeaderParseError::MissingContentLength),
"{:?}",
err
);
@ -889,10 +968,10 @@ mod tests {
#[test]
fn header_parse_should_fail_if_content_length_invalid() {
let err = "Content-Length: -1\r\n\r\n"
.parse::<LspDataHeader>()
.parse::<LspHeader>()
.unwrap_err();
assert!(
matches!(err, LspDataHeaderParseError::InvalidContentLength(_)),
matches!(err, LspHeaderParseError::InvalidContentLength(_)),
"{:?}",
err
);
@ -901,10 +980,10 @@ mod tests {
#[test]
fn header_parse_should_fail_if_receive_an_unexpected_header_field() {
let err = "Content-Length: 123\r\nUnknown-Field: abc\r\n\r\n"
.parse::<LspDataHeader>()
.parse::<LspHeader>()
.unwrap_err();
assert!(
matches!(err, LspDataHeaderParseError::BadHeaderField),
matches!(err, LspHeaderParseError::BadHeaderField),
"{:?}",
err
);
@ -912,9 +991,7 @@ mod tests {
#[test]
fn header_parse_should_succeed_if_given_valid_content_length() {
let header = "Content-Length: 123\r\n\r\n"
.parse::<LspDataHeader>()
.unwrap();
let header = "Content-Length: 123\r\n\r\n".parse::<LspHeader>().unwrap();
assert_eq!(header.content_length, 123);
assert_eq!(header.content_type, None);
}
@ -923,14 +1000,14 @@ mod tests {
fn header_parse_should_support_optional_content_type() {
// Regular type
let header = "Content-Length: 123\r\nContent-Type: some content type\r\n\r\n"
.parse::<LspDataHeader>()
.parse::<LspHeader>()
.unwrap();
assert_eq!(header.content_length, 123);
assert_eq!(header.content_type.as_deref(), Some("some content type"));
// Type with colons
let header = "Content-Length: 123\r\nContent-Type: some:content:type\r\n\r\n"
.parse::<LspDataHeader>()
.parse::<LspHeader>()
.unwrap();
assert_eq!(header.content_length, 123);
assert_eq!(header.content_type.as_deref(), Some("some:content:type"));
@ -939,14 +1016,14 @@ mod tests {
#[test]
fn header_display_should_output_header_fields_with_appropriate_line_terminations() {
// Without content type
let header = LspDataHeader {
let header = LspHeader {
content_length: 123,
content_type: None,
};
assert_eq!(header.to_string(), "Content-Length: 123\r\n\r\n");
// With content type
let header = LspDataHeader {
let header = LspHeader {
content_length: 123,
content_type: Some(String::from("some type")),
};
@ -958,21 +1035,111 @@ mod tests {
#[test]
fn content_parse_should_succeed_if_valid_json() {
let content = "{\"hello\": \"world\"}".parse::<LspDataContent>().unwrap();
let content = "{\"hello\": \"world\"}".parse::<LspContent>().unwrap();
assert_eq!(content.as_ref(), &make_obj!({"hello": "world"}));
}
#[test]
fn content_parse_should_fail_if_invalid_json() {
assert!(
"not json".parse::<LspDataContent>().is_err(),
"not json".parse::<LspContent>().is_err(),
"Unexpectedly succeeded"
);
}
#[test]
fn content_display_should_output_content_as_json() {
let content = LspDataContent(make_obj!({"hello": "world"}));
let content = LspContent(make_obj!({"hello": "world"}));
assert_eq!(content.to_string(), "{\n \"hello\": \"world\"\n}");
}
#[test]
fn content_convert_local_scheme_to_distant_should_convert_keys_and_values() {
let mut content = LspContent(make_obj!({
"distant://key1": "file://value1",
"file://key2": "distant://value2",
"key3": ["file://value3", "distant://value4"],
"key4": {
"distant://key5": "file://value5",
"file://key6": "distant://value6",
"key7": [
{
"distant://key8": "file://value8",
"file://key9": "distant://value9",
}
]
},
"key10": null,
"key11": 123,
"key12": true,
}));
content.convert_local_scheme_to_distant();
assert_eq!(
content.0,
make_obj!({
"distant://key1": "distant://value1",
"distant://key2": "distant://value2",
"key3": ["distant://value3", "distant://value4"],
"key4": {
"distant://key5": "distant://value5",
"distant://key6": "distant://value6",
"key7": [
{
"distant://key8": "distant://value8",
"distant://key9": "distant://value9",
}
]
},
"key10": null,
"key11": 123,
"key12": true,
})
);
}
#[test]
fn content_convert_distant_scheme_to_local_should_convert_keys_and_values() {
let content = LspContent(make_obj!({
"distant://key1": "file://value1",
"file://key2": "distant://value2",
"key3": ["file://value3", "distant://value4"],
"key4": {
"distant://key5": "file://value5",
"file://key6": "distant://value6",
"key7": [
{
"distant://key8": "file://value8",
"file://key9": "distant://value9",
}
]
},
"key10": null,
"key11": 123,
"key12": true,
}));
content.convert_distant_scheme_to_local();
assert_eq!(
content.0,
make_obj!({
"file://key1": "file://value1",
"file://key2": "file://value2",
"key3": ["file://value3", "file://value4"],
"key4": {
"file://key5": "file://value5",
"file://key6": "file://value6",
"key7": [
{
"file://key8": "file://value8",
"file://key9": "file://value9",
}
]
},
"key10": null,
"key11": 123,
"key12": true,
})
);
}
}

@ -0,0 +1,197 @@
use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
use crate::core::{client::Session, net::DataStream};
use std::{
fmt::Write,
io::{self, Cursor, Read},
ops::{Deref, DerefMut},
};
mod data;
pub use data::*;
/// Represents an LSP server process on a remote machine
pub struct RemoteLspProcess {
inner: RemoteProcess,
pub stdin: Option<RemoteLspStdin>,
pub stdout: Option<RemoteLspStdout>,
pub stderr: Option<RemoteLspStderr>,
}
impl RemoteLspProcess {
/// Spawns the specified process on the remote machine using the given session, treating
/// the process like an LSP server
pub async fn spawn<T>(
session: Session<T>,
cmd: String,
args: Vec<String>,
) -> Result<Self, RemoteProcessError>
where
T: DataStream + 'static,
{
let mut inner = RemoteProcess::spawn(session, cmd, args).await?;
let stdin = inner.stdin.take().map(RemoteLspStdin::new);
let stdout = inner.stdout.take().map(RemoteLspStdout::new);
let stderr = inner.stderr.take().map(RemoteLspStderr::new);
Ok(RemoteLspProcess {
inner,
stdin,
stdout,
stderr,
})
}
}
impl Deref for RemoteLspProcess {
type Target = RemoteProcess;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for RemoteLspProcess {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// A handle to a remote LSP process' standard input (stdin)
pub struct RemoteLspStdin {
inner: RemoteStdin,
buf: Option<String>,
}
impl RemoteLspStdin {
pub fn new(inner: RemoteStdin) -> Self {
Self { inner, buf: None }
}
/// Writes data to the stdin of a specific remote process
pub async fn write(&mut self, data: &str) -> io::Result<()> {
let mut queue = Vec::new();
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(data),
None => self.buf = Some(data.to_string()),
}
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
}
// Process and then send out each LSP message in our queue
for mut data in queue {
// Convert distant:// to file://
data.mut_content().convert_distant_scheme_to_local();
self.inner.write(&data.to_string()).await?;
}
Ok(())
}
}
/// A handle to a remote LSP process' standard output (stdout)
pub struct RemoteLspStdout {
inner: RemoteStdout,
buf: Option<String>,
}
impl RemoteLspStdout {
pub fn new(inner: RemoteStdout) -> Self {
Self { inner, buf: None }
}
pub async fn read(&mut self) -> io::Result<String> {
let mut queue = Vec::new();
let data = self.inner.read().await?;
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(&data),
None => self.buf = Some(data),
}
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
}
// Process and then add each LSP message as output
let mut out = String::new();
for mut data in queue {
// Convert file:// to distant://
data.mut_content().convert_local_scheme_to_distant();
write!(&mut out, "{}", data).unwrap();
}
Ok(out)
}
}
/// A handle to a remote LSP process' stderr
pub struct RemoteLspStderr {
inner: RemoteStderr,
buf: Option<String>,
}
impl RemoteLspStderr {
pub fn new(inner: RemoteStderr) -> Self {
Self { inner, buf: None }
}
pub async fn read(&mut self) -> io::Result<String> {
let mut queue = Vec::new();
let data = self.inner.read().await?;
// Create or insert into our buffer
match &mut self.buf {
Some(buf) => buf.push_str(&data),
None => self.buf = Some(data),
}
// Read LSP messages from our internal buffer
let buf = self.buf.take().unwrap();
let mut cursor = Cursor::new(buf);
while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
queue.push(data);
}
// Keep remainder of string not processed as LSP message in buffer
if (cursor.position() as usize) < cursor.get_ref().len() {
let mut buf = String::new();
cursor.read_to_string(&mut buf)?;
self.buf = Some(buf);
}
// Process and then add each LSP message as output
let mut out = String::new();
for mut data in queue {
// Convert file:// to distant://
data.mut_content().convert_local_scheme_to_distant();
write!(&mut out, "{}", data).unwrap();
}
Ok(out)
}
}

@ -0,0 +1,16 @@
mod lsp;
mod process;
mod session;
mod utils;
// TODO: Make wrappers around a connection to facilitate the types
// of engagements
//
// 1. Command -> Single request/response through a future
// 2. Proxy -> Does proc-run and waits until proc-done received,
// exposing a sender for stdin and receivers for stdout/stderr,
// and supporting a future await for completion with exit code
// 3.
pub use lsp::*;
pub use process::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
pub use session::*;

@ -0,0 +1,286 @@
use crate::core::{
client::{utils, Session},
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, RequestData, Response, ResponseData},
net::{DataStream, TransportError},
};
use derive_more::{Display, Error, From};
use tokio::{
io,
sync::mpsc,
task::{JoinError, JoinHandle},
};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
#[derive(Debug, Display, Error, From)]
pub enum RemoteProcessError {
/// When the process receives an unexpected response
BadResponse,
/// When attempting to relay stdout/stderr over channels, but the channels fail
ChannelDead,
/// When process is unable to read stdout/stderr from the server
/// fast enough, resulting in dropped data
Overloaded,
/// When the communication over the wire has issues
TransportError(TransportError),
/// When the stream of responses from the server closes without receiving
/// an indicator of the process' exit status
UnexpectedEof,
/// When attempting to wait on the remote process, but the internal task joining failed
WaitFailed(JoinError),
}
/// Represents a process on a remote machine
pub struct RemoteProcess {
/// Id of the process
id: usize,
/// Task that forwards stdin to the remote process by bundling it as stdin requests
req_task: JoinHandle<Result<(), RemoteProcessError>>,
/// Task that reads in new responses, which returns the success and optional
/// exit code once the process has completed
res_task: JoinHandle<Result<(bool, Option<i32>), RemoteProcessError>>,
/// Sender for stdin
pub stdin: Option<RemoteStdin>,
/// Receiver for stdout
pub stdout: Option<RemoteStdout>,
/// Receiver for stderr
pub stderr: Option<RemoteStderr>,
/// Sender for kill events
kill: mpsc::Sender<()>,
}
impl RemoteProcess {
/// Spawns the specified process on the remote machine using the given session
pub async fn spawn<T>(
mut session: Session<T>,
cmd: String,
args: Vec<String>,
) -> Result<Self, RemoteProcessError>
where
T: DataStream + 'static,
{
let tenant = utils::new_tenant();
// Submit our run request and wait for a response
let res = session
.send(Request::new(
tenant.as_str(),
vec![RequestData::ProcRun { cmd, args }],
))
.await?;
// We expect a singular response back
if res.payload.len() != 1 {
return Err(RemoteProcessError::BadResponse);
}
// Response should be proc starting
let id = match res.payload.into_iter().next().unwrap() {
ResponseData::ProcStart { id } => id,
_ => return Err(RemoteProcessError::BadResponse),
};
// Create channels for our stdin/stdout/stderr
let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
// Now we spawn a task to handle future responses that are async
// such as ProcStdout, ProcStderr, and ProcDone
let stream = session.to_response_broadcast_stream();
let res_task = tokio::spawn(async move {
process_incoming_responses(id, stream, stdout_tx, stderr_tx).await
});
// Spawn a task that takes stdin from our channel and forwards it to the remote process
let (kill_tx, kill_rx) = mpsc::channel(1);
let req_task = tokio::spawn(async move {
process_outgoing_requests(tenant, id, session, stdin_rx, kill_rx).await
});
Ok(Self {
id,
req_task,
res_task,
stdin: Some(RemoteStdin(stdin_tx)),
stdout: Some(RemoteStdout(stdout_rx)),
stderr: Some(RemoteStderr(stderr_rx)),
kill: kill_tx,
})
}
/// Returns the id of the running process
pub fn id(&self) -> usize {
self.id
}
/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
self.res_task.await?
}
/// Aborts the process by forcing its response task to shutdown, which means that a call
/// to `wait` will return an error. Note that this does **not** send a kill request, so if
/// you want to be nice you should send the request before aborting.
pub fn abort(&self) {
self.req_task.abort();
self.res_task.abort();
}
/// Submits a kill request for the running process
pub async fn kill(&mut self) -> Result<(), RemoteProcessError> {
self.kill
.send(())
.await
.map_err(|_| RemoteProcessError::ChannelDead)?;
Ok(())
}
}
/// A handle to a remote process' standard input (stdin)
pub struct RemoteStdin(mpsc::Sender<String>);
impl RemoteStdin {
/// Writes data to the stdin of a specific remote process
pub async fn write(&mut self, data: impl Into<String>) -> io::Result<()> {
self.0
.send(data.into())
.await
.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))
}
}
/// A handle to a remote process' standard output (stdout)
pub struct RemoteStdout(mpsc::Receiver<String>);
impl RemoteStdout {
/// Retrieves the latest stdout for a specific remote process
pub async fn read(&mut self) -> io::Result<String> {
self.0
.recv()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
}
}
/// A handle to a remote process' stderr
pub struct RemoteStderr(mpsc::Receiver<String>);
impl RemoteStderr {
/// Retrieves the latest stderr for a specific remote process
pub async fn read(&mut self) -> io::Result<String> {
self.0
.recv()
.await
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
}
}
/// Helper function that loops, processing outgoing stdin requests to a remote process as well as
/// supporting a kill request to terminate the remote process
async fn process_outgoing_requests<T>(
tenant: String,
id: usize,
mut session: Session<T>,
mut stdin_rx: mpsc::Receiver<String>,
mut kill_rx: mpsc::Receiver<()>,
) -> Result<(), RemoteProcessError>
where
T: DataStream,
{
loop {
tokio::select! {
data = stdin_rx.recv() => {
match data {
Some(data) => session.fire(
Request::new(
tenant.as_str(),
vec![RequestData::ProcStdin { id, data }]
)
).await?,
None => break Err(RemoteProcessError::ChannelDead),
}
}
msg = kill_rx.recv() => {
if msg.is_some() {
session
.fire(Request::new(
tenant.as_str(),
vec![RequestData::ProcKill { id }],
))
.await?;
break Ok(());
} else {
break Err(RemoteProcessError::ChannelDead);
}
}
}
}
}
/// Helper function that loops, processing incoming stdout & stderr requests from a remote process
async fn process_incoming_responses(
proc_id: usize,
mut stream: BroadcastStream<Response>,
stdout_tx: mpsc::Sender<String>,
stderr_tx: mpsc::Sender<String>,
) -> Result<(bool, Option<i32>), RemoteProcessError> {
let mut result = Err(RemoteProcessError::UnexpectedEof);
while let Some(res) = stream.next().await {
match res {
Ok(res) => {
// Check if any of the payload data is the termination
let exit_status = res.payload.iter().find_map(|data| match data {
ResponseData::ProcDone { id, success, code } if *id == proc_id => {
Some((*success, *code))
}
_ => None,
});
// Next, check for stdout/stderr and send them along our channels
// TODO: What should we do about unexpected data? For now, just ignore
for data in res.payload {
match data {
ResponseData::ProcStdout { id, data } if id == proc_id => {
if let Err(_) = stdout_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
}
ResponseData::ProcStderr { id, data } if id == proc_id => {
if let Err(_) = stderr_tx.send(data).await {
result = Err(RemoteProcessError::ChannelDead);
break;
}
}
_ => {}
}
}
// If we got a termination, then exit accordingly
if let Some((success, code)) = exit_status {
result = Ok((success, code));
break;
}
}
Err(_) => {
result = Err(RemoteProcessError::Overloaded);
break;
}
}
}
result
}

@ -10,84 +10,84 @@ use std::{
use tokio::{io, net::lookup_host};
#[derive(Debug, PartialEq, Eq)]
pub struct Session {
pub struct SessionInfo {
pub host: String,
pub port: u16,
pub auth_key: SecretKey,
}
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum SessionParseError {
pub enum SessionInfoParseError {
#[display(fmt = "Prefix of string is invalid")]
BadPrefix,
#[display(fmt = "Bad hex key for session")]
BadSessionHexKey,
BadHexKey,
#[display(fmt = "Invalid key for session")]
InvalidSessionKey,
InvalidKey,
#[display(fmt = "Invalid port for session")]
InvalidSessionPort,
InvalidPort,
#[display(fmt = "Missing address for session")]
MissingSessionAddr,
MissingAddr,
#[display(fmt = "Missing key for session")]
MissingSessionKey,
MissingKey,
#[display(fmt = "Missing port for session")]
MissingSessionPort,
MissingPort,
}
impl From<SessionParseError> for io::Error {
fn from(x: SessionParseError) -> Self {
impl From<SessionInfoParseError> for io::Error {
fn from(x: SessionInfoParseError) -> Self {
io::Error::new(io::ErrorKind::InvalidData, x)
}
}
impl FromStr for Session {
type Err = SessionParseError;
impl FromStr for SessionInfo {
type Err = SessionInfoParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut tokens = s.split(' ').take(5);
// First, validate that we have the appropriate prefix
if tokens.next().ok_or(SessionParseError::BadPrefix)? != "DISTANT" {
return Err(SessionParseError::BadPrefix);
if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "DISTANT" {
return Err(SessionInfoParseError::BadPrefix);
}
if tokens.next().ok_or(SessionParseError::BadPrefix)? != "DATA" {
return Err(SessionParseError::BadPrefix);
if tokens.next().ok_or(SessionInfoParseError::BadPrefix)? != "DATA" {
return Err(SessionInfoParseError::BadPrefix);
}
// Second, load up the address without parsing it
let host = tokens
.next()
.ok_or(SessionParseError::MissingSessionAddr)?
.ok_or(SessionInfoParseError::MissingAddr)?
.trim()
.to_string();
// Third, load up the port and parse it into a number
let port = tokens
.next()
.ok_or(SessionParseError::MissingSessionPort)?
.ok_or(SessionInfoParseError::MissingPort)?
.trim()
.parse::<u16>()
.map_err(|_| SessionParseError::InvalidSessionPort)?;
.map_err(|_| SessionInfoParseError::InvalidPort)?;
// Fourth, load up the key and convert it back into a secret key from a hex slice
let auth_key = SecretKey::from_slice(
&hex::decode(
tokens
.next()
.ok_or(SessionParseError::MissingSessionKey)?
.ok_or(SessionInfoParseError::MissingKey)?
.trim(),
)
.map_err(|_| SessionParseError::BadSessionHexKey)?,
.map_err(|_| SessionInfoParseError::BadHexKey)?,
)
.map_err(|_| SessionParseError::InvalidSessionKey)?;
.map_err(|_| SessionInfoParseError::InvalidKey)?;
Ok(Session {
Ok(SessionInfo {
host,
port,
auth_key,
@ -95,7 +95,7 @@ impl FromStr for Session {
}
}
impl Session {
impl SessionInfo {
/// Loads session from environment variables
pub fn from_environment() -> io::Result<Self> {
fn to_err(x: env::VarError) -> io::Error {
@ -159,40 +159,40 @@ impl Session {
}
/// Provides operations related to working with a session that is disk-based
pub struct SessionFile {
pub struct SessionInfoFile {
path: PathBuf,
session: Session,
session: SessionInfo,
}
impl AsRef<Path> for SessionFile {
impl AsRef<Path> for SessionInfoFile {
fn as_ref(&self) -> &Path {
self.as_path()
}
}
impl AsRef<Session> for SessionFile {
fn as_ref(&self) -> &Session {
impl AsRef<SessionInfo> for SessionInfoFile {
fn as_ref(&self) -> &SessionInfo {
self.as_session()
}
}
impl Deref for SessionFile {
type Target = Session;
impl Deref for SessionInfoFile {
type Target = SessionInfo;
fn deref(&self) -> &Self::Target {
&self.session
}
}
impl From<SessionFile> for Session {
fn from(sf: SessionFile) -> Self {
impl From<SessionInfoFile> for SessionInfo {
fn from(sf: SessionInfoFile) -> Self {
sf.session
}
}
impl SessionFile {
impl SessionInfoFile {
/// Creates a new inmemory pointer to a session and its file
pub fn new(path: impl Into<PathBuf>, session: Session) -> Self {
pub fn new(path: impl Into<PathBuf>, session: SessionInfo) -> Self {
Self {
path: path.into(),
session,
@ -205,7 +205,7 @@ impl SessionFile {
}
/// Returns a reference to the session
pub fn as_session(&self) -> &Session {
pub fn as_session(&self) -> &SessionInfo {
&self.session
}

@ -1,9 +1,8 @@
use crate::core::{
client::utils,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
net::{DataStream, SecretKey, Transport, TransportError, TransportWriteHalf},
session::Session,
utils,
net::{DataStream, InmemoryStream, SecretKey, Transport, TransportError, TransportWriteHalf},
};
use log::*;
use std::{
@ -20,14 +19,17 @@ use tokio::{
};
use tokio_stream::wrappers::BroadcastStream;
mod info;
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
/// Represents a client that can make requests against a server
pub struct Client<T>
/// Represents a session with a remote server that can be used to send requests & receive responses
pub struct Session<T>
where
T: DataStream,
{
/// Underlying transport used by client
/// Underlying transport used by session
t_write: TransportWriteHalf<T::Write>,
/// Collection of callbacks to be invoked upon receiving a response to a request
@ -45,12 +47,21 @@ where
response_task: JoinHandle<()>,
}
impl Client<TcpStream> {
/// Connect to a remote TCP session
pub async fn tcp_connect(session: Session) -> io::Result<Self> {
let transport = Transport::<TcpStream>::connect(session).await?;
impl Session<InmemoryStream> {
/// Creates a session around an inmemory transport
pub async fn from_inmemory_transport(transport: Transport<InmemoryStream>) -> io::Result<Self> {
Self::inner_connect(transport).await
}
}
impl Session<TcpStream> {
/// Connect to a remote TCP server using the provided information
pub async fn tcp_connect(info: SessionInfo) -> io::Result<Self> {
let addr = info.to_socket_addr().await?;
let transport =
Transport::<TcpStream>::connect(addr, Some(Arc::new(info.auth_key))).await?;
debug!(
"Client has connected to {}",
"Session has been established with {}",
transport
.peer_addr()
.map(|x| x.to_string())
@ -59,16 +70,16 @@ impl Client<TcpStream> {
Self::inner_connect(transport).await
}
/// Connect to a remote TCP session, timing out after duration has passed
pub async fn tcp_connect_timeout(session: Session, duration: Duration) -> io::Result<Self> {
utils::timeout(duration, Self::tcp_connect(session))
/// Connect to a remote TCP server, timing out after duration has passed
pub async fn tcp_connect_timeout(info: SessionInfo, duration: Duration) -> io::Result<Self> {
utils::timeout(duration, Self::tcp_connect(info))
.await
.and_then(convert::identity)
}
}
#[cfg(unix)]
impl Client<tokio::net::UnixStream> {
impl Session<tokio::net::UnixStream> {
/// Connect to a proxy unix socket
pub async fn unix_connect(
path: impl AsRef<std::path::Path>,
@ -76,7 +87,7 @@ impl Client<tokio::net::UnixStream> {
) -> io::Result<Self> {
let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;
debug!(
"Client has connected to {}",
"Session has been established with {}",
transport
.peer_addr()
.map(|x| format!("{:?}", x))
@ -97,11 +108,11 @@ impl Client<tokio::net::UnixStream> {
}
}
impl<T> Client<T>
impl<T> Session<T>
where
T: DataStream,
{
/// Establishes a connection using the provided session
/// Establishes a connection using the provided transport
async fn inner_connect(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
@ -115,7 +126,7 @@ where
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Client got response: {:?}", res);
trace!("Incoming response: {:?}", res);
let maybe_callback = res
.origin_id
.as_ref()
@ -123,14 +134,14 @@ where
// If there is an origin to this response, trigger the callback
if let Some(tx) = maybe_callback {
trace!("Client has callback! Triggering!");
trace!("Callback exists for response! Triggering!");
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
// Otherwise, this goes into the junk draw of response handlers
} else {
trace!("Client does not have callback! Broadcasting!");
trace!("Callback missing for response! Broadcasting!");
if let Err(x) = broadcast_2.send(res) {
error!("Failed to trigger broadcast: {}", x);
}
@ -154,13 +165,13 @@ where
})
}
/// Waits for the client to terminate, which results when the receiving end of the network
/// connection is closed (or the client is shutdown)
/// Waits for the session to terminate, which results when the receiving end of the network
/// connection is closed (or the session is shutdown)
pub async fn wait(self) -> Result<(), JoinError> {
self.response_task.await
}
/// Abort the client's current connection by forcing its response task to shutdown
/// Abort the session's current connection by forcing its response task to shutdown
pub fn abort(&self) {
self.response_task.abort()
}
@ -210,7 +221,7 @@ where
.and_then(convert::identity)
}
/// Clones a new instance of the broadcaster used by the client
/// Clones a new instance of the broadcaster used by the session
pub fn to_response_broadcaster(&self) -> broadcast::Sender<Response> {
self.broadcast.clone()
}
@ -232,14 +243,14 @@ mod tests {
use crate::core::{
constants::test::TENANT,
data::{RequestData, ResponseData},
net::transport::test::make_transport_pair,
net::test::make_transport_pair,
};
use std::time::Duration;
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let mut session = Session::inner_connect(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
let res = Response::new(
@ -250,7 +261,7 @@ mod tests {
}],
);
let (actual, _) = tokio::join!(client.send(req), t2.send(res.clone()));
let (actual, _) = tokio::join!(session.send(req), t2.send(res.clone()));
match actual {
Ok(actual) => assert_eq!(actual, res),
x => panic!("Unexpected response: {:?}", x),
@ -260,10 +271,10 @@ mod tests {
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let mut session = Session::inner_connect(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match client.send_timeout(req, Duration::from_millis(30)).await {
match session.send_timeout(req, Duration::from_millis(30)).await {
Err(TransportError::IoError(x)) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
x => panic!("Unexpected response: {:?}", x),
}
@ -275,10 +286,10 @@ mod tests {
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = make_transport_pair();
let mut client = Client::inner_connect(t1).await.unwrap();
let mut session = Session::inner_connect(t1).await.unwrap();
let req = Request::new(TENANT, vec![RequestData::ProcList {}]);
match client.fire(req).await {
match session.fire(req).await {
Ok(_) => {}
x => panic!("Unexpected response: {:?}", x),
}

@ -0,0 +1,18 @@
use std::{future::Future, time::Duration};
use tokio::{io, time};
// Generates a new tenant name
pub fn new_tenant() -> String {
format!("tenant_{}{}", rand::random::<u16>(), rand::random::<u8>())
}
// Wraps a future in a tokio timeout call, transforming the error into
// an io error
pub async fn timeout<T, F>(d: Duration, f: F) -> io::Result<T>
where
F: Future<Output = T>,
{
time::timeout(d, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
}

@ -1,7 +1,5 @@
pub mod client;
pub mod constants;
pub mod data;
pub mod lsp;
pub mod net;
pub mod session;
pub mod state;
pub mod utils;
pub mod server;

@ -1,8 +1,5 @@
mod transport;
pub use transport::*;
mod client;
pub use client::Client;
// Re-export commonly-used orion structs
pub use orion::aead::SecretKey;

@ -70,7 +70,9 @@ impl AsyncWrite for InmemoryStream {
}
}
/// Read portion of an inmemory channel
pub struct InmemoryStreamReadHalf(mpsc::Receiver<Vec<u8>>);
impl AsyncRead for InmemoryStreamReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
@ -87,7 +89,9 @@ impl AsyncRead for InmemoryStreamReadHalf {
}
}
/// Write portion of an inmemory channel
pub struct InmemoryStreamWriteHalf(mpsc::Sender<Vec<u8>>);
impl AsyncWrite for InmemoryStreamWriteHalf {
fn poll_write(
self: Pin<&mut Self>,

@ -1,11 +1,10 @@
use super::{DataStream, Transport};
use crate::core::session::Session;
use super::{DataStream, SecretKey, Transport};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
TcpStream, ToSocketAddrs,
},
};
@ -28,10 +27,13 @@ impl Transport<TcpStream> {
/// Establishes a connection using the provided session and performs a handshake to establish
/// means of encryption, returning a transport ready to communicate with the other side
///
/// TCP Streams will always use a session's authentication key
pub async fn connect(session: Session) -> io::Result<Self> {
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
Self::from_handshake(stream, Some(Arc::new(session.auth_key))).await
/// Takes an optional authentication key
pub async fn connect(
addrs: impl ToSocketAddrs,
auth_key: Option<Arc<SecretKey>>,
) -> io::Result<Self> {
let stream = TcpStream::connect(addrs).await?;
Self::from_handshake(stream, auth_key).await
}
/// Returns the address of the peer the transport is connected to

@ -1,5 +1,4 @@
use super::{DataStream, Transport};
use orion::aead::SecretKey;
use super::{DataStream, SecretKey, Transport};
use std::sync::Arc;
use tokio::{
io,

@ -3,7 +3,7 @@ use crate::core::{
data::{
self, DirEntry, FileType, Request, RequestData, Response, ResponseData, RunningProcess,
},
state::{Process, ServerState},
server::state::{Process, State},
};
use derive_more::{Display, Error, From};
use futures::future;
@ -24,7 +24,7 @@ use tokio::{
use walkdir::WalkDir;
pub type Reply = mpsc::Sender<Response>;
type HState = Arc<Mutex<ServerState<SocketAddr>>>;
type HState = Arc<Mutex<State<SocketAddr>>>;
#[derive(Debug, Display, Error, From)]
pub enum ServerError {

@ -0,0 +1,222 @@
mod handler;
mod port;
mod state;
mod utils;
pub use port::{PortRange, PortRangeParseError};
use state::State;
use crate::core::{
data::{Request, Response},
net::{SecretKey, Transport, TransportReadHalf, TransportWriteHalf},
};
use log::*;
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use tokio::{
io,
net::{tcp, TcpListener, TcpStream},
runtime::Handle,
sync::{mpsc, Mutex, Notify},
task::{JoinError, JoinHandle},
time::Duration,
};
/// Represents a server that listens for requests, processes them, and sends responses
pub struct Server {
port: u16,
state: Arc<Mutex<State<SocketAddr>>>,
auth_key: Arc<SecretKey>,
notify: Arc<Notify>,
conn_task: JoinHandle<()>,
}
impl Server {
pub async fn bind(
addr: IpAddr,
port: PortRange,
shutdown_after: Option<Duration>,
max_msg_capacity: usize,
) -> io::Result<Self> {
debug!("Binding to {} in range {}", addr, port);
let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?;
let port = listener.local_addr()?.port();
debug!("Bound to port: {}", port);
// Build our state for the server
let state: Arc<Mutex<State<SocketAddr>>> = Arc::new(Mutex::new(State::default()));
let auth_key = Arc::new(SecretKey::default());
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
// Spawn our connection task
let state_2 = Arc::clone(&state);
let auth_key_2 = Arc::clone(&auth_key);
let notify_2 = Arc::clone(&notify);
let conn_task = tokio::spawn(async move {
connection_loop(
listener,
state_2,
auth_key_2,
ct,
notify_2,
max_msg_capacity,
)
.await
});
Ok(Self {
port,
state,
auth_key,
notify,
conn_task,
})
}
/// Returns the port this server is bound to
pub fn port(&self) -> u16 {
self.port
}
/// Returns a string representing the auth key as hex
pub fn to_unprotected_hex_auth_key(&self) -> String {
hex::encode(self.auth_key.unprotected_as_bytes())
}
/// Waits for the server to terminate
pub async fn wait(self) -> Result<(), JoinError> {
self.conn_task.await
}
/// Shutdown the server
pub fn shutdown(&self) {
self.notify.notify_one()
}
}
async fn connection_loop(
listener: TcpListener,
state: Arc<Mutex<State<SocketAddr>>>,
auth_key: Arc<SecretKey>,
tracker: Arc<Mutex<utils::ConnTracker>>,
notify: Arc<Notify>,
max_msg_capacity: usize,
) {
loop {
tokio::select! {
result = listener.accept() => {match result {
Ok((conn, addr)) => {
if let Err(x) = on_new_conn(
conn,
addr,
Arc::clone(&state),
Arc::clone(&auth_key),
Arc::clone(&tracker),
max_msg_capacity
).await {
error!("<Conn @ {}> Failed handshake: {}", addr, x);
}
}
Err(x) => {
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
}
}
}
/// Processes a new connection, performing a handshake, and then spawning two tasks to handle
/// input and output, returning join handles for the input and output tasks respectively
async fn on_new_conn(
conn: TcpStream,
addr: SocketAddr,
state: Arc<Mutex<State<SocketAddr>>>,
auth_key: Arc<SecretKey>,
tracker: Arc<Mutex<utils::ConnTracker>>,
max_msg_capacity: usize,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> {
// Establish a proper connection via a handshake,
// discarding the connection otherwise
let transport = Transport::from_handshake(conn, Some(auth_key)).await?;
// Split the transport into read and write halves so we can handle input
// and output concurrently
let (t_read, t_write) = transport.into_split();
let (tx, rx) = mpsc::channel(max_msg_capacity);
let ct_2 = Arc::clone(&tracker);
// Spawn a new task that loops to handle requests from the client
let req_task = tokio::spawn({
let f = request_loop(addr, Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
ct_2.lock().await.increment();
f.await;
state.lock().await.cleanup_client(addr).await;
ct_2.lock().await.decrement();
}
});
// Spawn a new task that loops to handle responses to the client
let res_task = tokio::spawn(async move { response_loop(addr, t_write, rx).await });
Ok((req_task, res_task))
}
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
async fn request_loop(
addr: SocketAddr,
state: Arc<Mutex<State<SocketAddr>>>,
mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
tx: mpsc::Sender<Response>,
) {
loop {
match transport.receive::<Request>().await {
Ok(Some(req)) => {
debug!(
"<Conn @ {}> Received request of type{} {}",
addr,
if req.payload.len() > 1 { "s" } else { "" },
req.to_payload_type_string()
);
if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await {
error!("<Conn @ {}> {}", addr, x);
break;
}
}
Ok(None) => {
info!("<Conn @ {}> Closed connection", addr);
break;
}
Err(x) => {
error!("<Conn @ {}> {}", addr, x);
break;
}
}
}
}
/// Repeatedly sends responses out over the wire
async fn response_loop(
addr: SocketAddr,
mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>,
mut rx: mpsc::Receiver<Response>,
) {
while let Some(res) = rx.recv().await {
if let Err(x) = transport.send(res).await {
error!("<Conn @ {}> {}", addr, x);
break;
}
}
}

@ -0,0 +1,66 @@
use derive_more::{Display, Error};
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
};
/// Represents some range of ports
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
fmt = "{}{}",
start,
"end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()"
)]
pub struct PortRange {
pub start: u16,
pub end: Option<u16>,
}
impl PortRange {
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
let mut socket_addrs = Vec::new();
let addr = addr.into();
for port in self.start..=self.end.unwrap_or(self.start) {
socket_addrs.push(SocketAddr::from((addr, port)));
}
socket_addrs
}
}
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum PortRangeParseError {
InvalidPort,
MissingPort,
}
impl FromStr for PortRange {
type Err = PortRangeParseError;
/// Parses PORT into single range or PORT1:PORTN into full range
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut tokens = s.trim().split(':');
let start = tokens
.next()
.ok_or(PortRangeParseError::MissingPort)?
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?;
let end = if let Some(token) = tokens.next() {
Some(
token
.parse::<u16>()
.map_err(|_| PortRangeParseError::InvalidPort)?,
)
} else {
None
};
if tokens.next().is_some() {
return Err(PortRangeParseError::InvalidPort);
}
Ok(Self { start, end })
}
}

@ -3,7 +3,7 @@ use std::{collections::HashMap, fmt::Debug, hash::Hash};
use tokio::sync::{mpsc, oneshot};
/// Holds state related to multiple clients managed by a server
pub struct ServerState<ClientId>
pub struct State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{
@ -14,7 +14,7 @@ where
client_processes: HashMap<ClientId, Vec<usize>>,
}
impl<ClientId> ServerState<ClientId>
impl<ClientId> State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{
@ -50,7 +50,7 @@ where
}
}
impl<ClientId> Default for ServerState<ClientId>
impl<ClientId> Default for State<ClientId>
where
ClientId: Debug + Hash + PartialEq + Eq,
{

@ -1,33 +1,11 @@
use log::*;
use std::{
future::Future,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use std::{sync::Arc, time::Duration};
use tokio::{
io,
runtime::Handle,
sync::{Mutex, Notify},
time::{self, Instant},
};
// Generates a new tenant name
pub fn new_tenant() -> String {
format!("tenant_{}{}", rand::random::<u16>(), rand::random::<u8>())
}
// Wraps a future in a tokio timeout call, transforming the error into
// an io error
pub async fn timeout<T, F>(d: Duration, f: F) -> io::Result<T>
where
F: Future<Output = T>,
{
time::timeout(d, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
}
pub struct ConnTracker {
time: Instant,
cnt: usize,
@ -121,53 +99,3 @@ pub fn new_shutdown_task(
(ct, notify)
}
/// Wraps a string to provide some friendly read and write methods
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StringBuf(String);
impl StringBuf {
pub fn new() -> Self {
Self(String::new())
}
/// Consumes data within the buffer that represent full lines (end with a newline) and returns
/// the string containing those lines.
///
/// The remaining buffer contains are returned as the second part of the tuple
pub fn into_full_lines(mut self) -> (Option<String>, StringBuf) {
match self.rfind('\n') {
Some(idx) => {
let remaining = self.0.split_off(idx + 1);
(Some(self.0), Self(remaining))
}
None => (None, self),
}
}
}
impl From<String> for StringBuf {
fn from(x: String) -> Self {
Self(x)
}
}
impl From<StringBuf> for String {
fn from(x: StringBuf) -> Self {
x.0
}
}
impl Deref for StringBuf {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for StringBuf {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

@ -0,0 +1,2 @@
mod distant;
pub use distant::{PortRange, PortRangeParseError, Server as DistantServer};

@ -4,80 +4,9 @@ mod core;
pub use self::core::{data, net};
use log::error;
/// Represents an error that can be converted into an exit code
pub trait ExitCodeError: std::error::Error {
fn to_exit_code(&self) -> ExitCode;
fn to_i32(&self) -> i32 {
self.to_exit_code() as i32
}
}
impl ExitCodeError for std::io::Error {
fn to_exit_code(&self) -> ExitCode {
use std::io::ErrorKind;
match self.kind() {
ErrorKind::ConnectionAborted
| ErrorKind::ConnectionRefused
| ErrorKind::ConnectionReset
| ErrorKind::NotConnected => ExitCode::Unavailable,
ErrorKind::InvalidData => ExitCode::DataErr,
ErrorKind::TimedOut => ExitCode::TempFail,
_ => ExitCode::IoError,
}
}
}
impl ExitCodeError for core::net::TransportError {
fn to_exit_code(&self) -> ExitCode {
match self {
core::net::TransportError::IoError(x) => x.to_exit_code(),
_ => ExitCode::Protocol,
}
}
}
impl<T: ExitCodeError + 'static> From<T> for Box<dyn ExitCodeError> {
fn from(x: T) -> Self {
Box::new(x)
}
}
/// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum ExitCode {
/// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI
Usage = 64,
/// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data
/// is bad
DataErr = 65,
/// EX_NOINPUT (66) - being used when not getting expected data from launch
NoInput = 66,
/// EX_NOHOST (68) - being used when failed to resolve a host
NoHost = 68,
/// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem
Unavailable = 69,
/// EX_OSERR (71) - being used when fork failed
OsErr = 71,
/// EX_IOERR (74) - being used as catchall for IO errors
IoError = 74,
/// EX_TEMPFAIL (75) - being used when we get a timeout
TempFail = 75,
/// EX_PROTOCOL (76) - being used as catchall for transport errors
Protocol = 76,
}
/// Main entrypoint into the program
pub fn run() {
let opt = cli::opt::Opt::load();
let opt = cli::Opt::load();
let logger = init_logging(&opt.common);
if let Err(x) = opt.subcommand.run(opt.common) {
error!("Exiting due to error: {}", x);
@ -88,7 +17,7 @@ pub fn run() {
}
}
fn init_logging(opt: &cli::opt::CommonOpt) -> flexi_logger::LoggerHandle {
fn init_logging(opt: &cli::CommonOpt) -> flexi_logger::LoggerHandle {
use flexi_logger::{FileSpec, LevelFilter, LogSpecification, Logger};
let module = "distant";

Loading…
Cancel
Save