mirror of https://github.com/chipsenkbeil/distant
Big refactor that is not finished
parent
2b23cd379c
commit
1ca3cd7859
@ -0,0 +1,51 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// Wraps a string to provide some friendly read and write methods
//
// `Default` is derived so `StringBuf::default()` matches `StringBuf::new()`
// (empty string) and satisfies clippy's `new_without_default` lint
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct StringBuf(String);

impl StringBuf {
    /// Creates a new, empty buffer
    pub fn new() -> Self {
        Self(String::new())
    }

    /// Consumes data within the buffer that represent full lines (end with a newline) and returns
    /// the string containing those lines.
    ///
    /// The remaining buffer contents are returned as the second part of the tuple
    pub fn into_full_lines(mut self) -> (Option<String>, StringBuf) {
        // rfind locates the LAST newline; everything up to and including it is
        // complete lines, everything after it is a trailing partial line
        match self.0.rfind('\n') {
            Some(idx) => {
                let remaining = self.0.split_off(idx + 1);
                (Some(self.0), Self(remaining))
            }
            // No newline at all means no full line is available yet
            None => (None, self),
        }
    }
}

impl From<String> for StringBuf {
    fn from(x: String) -> Self {
        Self(x)
    }
}

impl From<StringBuf> for String {
    fn from(x: StringBuf) -> Self {
        x.0
    }
}

// Deref to String gives read access to all String/str methods
impl Deref for StringBuf {
    type Target = String;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// DerefMut enables in-place mutation such as push_str
impl DerefMut for StringBuf {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
|
@ -0,0 +1,75 @@
|
||||
use crate::core::net::TransportError;
|
||||
|
||||
/// Exit codes following https://www.freebsd.org/cgi/man.cgi?query=sysexits&sektion=3
//
// `Debug` derive added: public types should be debuggable, and it allows
// asserting on ExitCode values in tests
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ExitCode {
    /// EX_USAGE (64) - being used when arguments missing or bad arguments provided to CLI
    Usage = 64,

    /// EX_DATAERR (65) - being used when bad data received not in UTF-8 format or transport data
    /// is bad
    DataErr = 65,

    /// EX_NOINPUT (66) - being used when not getting expected data from launch
    NoInput = 66,

    /// EX_NOHOST (68) - being used when failed to resolve a host
    NoHost = 68,

    /// EX_UNAVAILABLE (69) - being used when IO error encountered where connection is problem
    Unavailable = 69,

    /// EX_SOFTWARE (70) - being used for internal errors that can occur like joining a task
    Software = 70,

    /// EX_OSERR (71) - being used when fork failed
    OsErr = 71,

    /// EX_IOERR (74) - being used as catchall for IO errors
    IoError = 74,

    /// EX_TEMPFAIL (75) - being used when we get a timeout
    TempFail = 75,

    /// EX_PROTOCOL (76) - being used as catchall for transport errors
    Protocol = 76,
}

/// Represents an error that can be converted into an exit code
pub trait ExitCodeError: std::error::Error {
    /// Maps this error to its sysexits-style [`ExitCode`]
    fn to_exit_code(&self) -> ExitCode;

    /// Convenience conversion to the raw i32 the process should exit with
    fn to_i32(&self) -> i32 {
        self.to_exit_code() as i32
    }
}

impl ExitCodeError for std::io::Error {
    fn to_exit_code(&self) -> ExitCode {
        use std::io::ErrorKind;
        match self.kind() {
            // Connection-level failures map to EX_UNAVAILABLE
            ErrorKind::ConnectionAborted
            | ErrorKind::ConnectionRefused
            | ErrorKind::ConnectionReset
            | ErrorKind::NotConnected => ExitCode::Unavailable,
            ErrorKind::InvalidData => ExitCode::DataErr,
            ErrorKind::TimedOut => ExitCode::TempFail,
            // Everything else falls back to the generic I/O exit code
            _ => ExitCode::IoError,
        }
    }
}
|
||||
|
||||
impl ExitCodeError for TransportError {
|
||||
fn to_exit_code(&self) -> ExitCode {
|
||||
match self {
|
||||
TransportError::IoError(x) => x.to_exit_code(),
|
||||
_ => ExitCode::Protocol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows any concrete [`ExitCodeError`] to be boxed into a trait object via
/// `.into()` or `?`, so callers can propagate heterogeneous error types
impl<T: ExitCodeError + 'static> From<T> for Box<dyn ExitCodeError> {
    fn from(x: T) -> Self {
        Box::new(x)
    }
}
|
@ -1,2 +1,10 @@
|
||||
pub mod opt;
|
||||
mod buf;
|
||||
mod exit;
|
||||
mod opt;
|
||||
mod output;
|
||||
mod session;
|
||||
mod subcommand;
|
||||
|
||||
pub use exit::{ExitCode, ExitCodeError};
|
||||
pub use opt::*;
|
||||
pub use output::ResponseOut;
|
||||
|
@ -0,0 +1,178 @@
|
||||
use crate::{
|
||||
cli::Format,
|
||||
core::data::{Error, Response, ResponseData},
|
||||
};
|
||||
use log::*;
|
||||
use std::io;
|
||||
|
||||
/// Represents the output content and destination
pub enum ResponseOut {
    /// Print to stdout without a trailing newline (flushed explicitly)
    Stdout(String),
    /// Print to stdout followed by a newline
    StdoutLine(String),
    /// Print to stderr without a trailing newline (flushed explicitly)
    Stderr(String),
    /// Print to stderr followed by a newline
    StderrLine(String),
    /// Produce no output at all
    None,
}
|
||||
|
||||
impl ResponseOut {
|
||||
/// Create a new output message for the given response based on the specified format
|
||||
pub fn new(format: Format, res: Response) -> io::Result<ResponseOut> {
|
||||
let payload_cnt = res.payload.len();
|
||||
|
||||
Ok(match format {
|
||||
Format::Json => ResponseOut::StdoutLine(format!(
|
||||
"{}",
|
||||
serde_json::to_string(&res)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?
|
||||
)),
|
||||
|
||||
// NOTE: For shell, we assume a singular entry in the response's payload
|
||||
Format::Shell if payload_cnt != 1 => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"Got {} entries in payload data, but shell expects exactly 1",
|
||||
payload_cnt
|
||||
),
|
||||
))
|
||||
}
|
||||
Format::Shell => format_shell(res.payload.into_iter().next().unwrap()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Consumes the output message, printing it based on its configuration
|
||||
pub fn print(self) {
|
||||
match self {
|
||||
Self::Stdout(x) => {
|
||||
// NOTE: Because we are not including a newline in the output,
|
||||
// it is not guaranteed to be written out. In the case of
|
||||
// LSP protocol, the JSON content is not followed by a
|
||||
// newline and was not picked up when the response was
|
||||
// sent back to the client; so, we need to manually flush
|
||||
use std::io::Write;
|
||||
print!("{}", x);
|
||||
if let Err(x) = std::io::stdout().lock().flush() {
|
||||
error!("Failed to flush stdout: {}", x);
|
||||
}
|
||||
}
|
||||
Self::StdoutLine(x) => println!("{}", x),
|
||||
Self::Stderr(x) => {
|
||||
use std::io::Write;
|
||||
eprint!("{}", x);
|
||||
if let Err(x) = std::io::stderr().lock().flush() {
|
||||
error!("Failed to flush stderr: {}", x);
|
||||
}
|
||||
}
|
||||
Self::StderrLine(x) => eprintln!("{}", x),
|
||||
Self::None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_shell(data: ResponseData) -> ResponseOut {
|
||||
match data {
|
||||
ResponseData::Ok => ResponseOut::None,
|
||||
ResponseData::Error(Error { kind, description }) => {
|
||||
ResponseOut::StderrLine(format!("Failed ({}): '{}'.", kind, description))
|
||||
}
|
||||
ResponseData::Blob { data } => {
|
||||
ResponseOut::StdoutLine(String::from_utf8_lossy(&data).to_string())
|
||||
}
|
||||
ResponseData::Text { data } => ResponseOut::StdoutLine(data),
|
||||
ResponseData::DirEntries { entries, .. } => ResponseOut::StdoutLine(format!(
|
||||
"{}",
|
||||
entries
|
||||
.into_iter()
|
||||
.map(|entry| {
|
||||
format!(
|
||||
"{}{}",
|
||||
entry.path.as_os_str().to_string_lossy(),
|
||||
if entry.file_type.is_dir() {
|
||||
// NOTE: This can be different from the server if
|
||||
// the server OS is unix and the client is
|
||||
// not or vice versa; for now, this doesn't
|
||||
// matter as we only support unix-based
|
||||
// operating systems, but something to keep
|
||||
// in mind
|
||||
std::path::MAIN_SEPARATOR.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n"),
|
||||
)),
|
||||
ResponseData::Exists(exists) => {
|
||||
if exists {
|
||||
ResponseOut::StdoutLine("Does exist.".to_string())
|
||||
} else {
|
||||
ResponseOut::StdoutLine("Does not exist.".to_string())
|
||||
}
|
||||
}
|
||||
ResponseData::Metadata {
|
||||
canonicalized_path,
|
||||
file_type,
|
||||
len,
|
||||
readonly,
|
||||
accessed,
|
||||
created,
|
||||
modified,
|
||||
} => ResponseOut::StdoutLine(format!(
|
||||
concat!(
|
||||
"{}",
|
||||
"Type: {}\n",
|
||||
"Len: {}\n",
|
||||
"Readonly: {}\n",
|
||||
"Created: {}\n",
|
||||
"Last Accessed: {}\n",
|
||||
"Last Modified: {}",
|
||||
),
|
||||
canonicalized_path
|
||||
.map(|p| format!("Canonicalized Path: {:?}\n", p))
|
||||
.unwrap_or_default(),
|
||||
file_type.as_ref(),
|
||||
len,
|
||||
readonly,
|
||||
created.unwrap_or_default(),
|
||||
accessed.unwrap_or_default(),
|
||||
modified.unwrap_or_default(),
|
||||
)),
|
||||
ResponseData::ProcEntries { entries } => ResponseOut::StdoutLine(format!(
|
||||
"{}",
|
||||
entries
|
||||
.into_iter()
|
||||
.map(|entry| format!("{}: {} {}", entry.id, entry.cmd, entry.args.join(" ")))
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n"),
|
||||
)),
|
||||
ResponseData::ProcStart { .. } => ResponseOut::None,
|
||||
ResponseData::ProcStdout { data, .. } => ResponseOut::Stdout(data),
|
||||
ResponseData::ProcStderr { data, .. } => ResponseOut::Stderr(data),
|
||||
ResponseData::ProcDone { id, success, code } => {
|
||||
if success {
|
||||
ResponseOut::None
|
||||
} else if let Some(code) = code {
|
||||
ResponseOut::StderrLine(format!("Proc {} failed with code {}", id, code))
|
||||
} else {
|
||||
ResponseOut::StderrLine(format!("Proc {} failed", id))
|
||||
}
|
||||
}
|
||||
ResponseData::SystemInfo {
|
||||
family,
|
||||
os,
|
||||
arch,
|
||||
current_dir,
|
||||
main_separator,
|
||||
} => ResponseOut::StdoutLine(format!(
|
||||
concat!(
|
||||
"Family: {:?}\n",
|
||||
"Operating System: {:?}\n",
|
||||
"Arch: {:?}\n",
|
||||
"Cwd: {:?}\n",
|
||||
"Path Sep: {:?}",
|
||||
),
|
||||
family, os, arch, current_dir, main_separator,
|
||||
)),
|
||||
}
|
||||
}
|
@ -0,0 +1,155 @@
|
||||
use crate::{
|
||||
cli::{buf::StringBuf, Format, ResponseOut},
|
||||
core::{
|
||||
client::Session,
|
||||
constants::MAX_PIPE_CHUNK_SIZE,
|
||||
data::{Request, Response},
|
||||
net::DataStream,
|
||||
},
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
io::{self, BufReader, Read},
|
||||
sync::Arc,
|
||||
thread,
|
||||
};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
|
||||
|
||||
/// Represents a wrapper around a session that provides CLI functionality such as reading from
/// stdin and piping results back out to stdout
pub struct CliSession<T>
where
    T: DataStream,
{
    // Underlying session used to send requests and receive responses
    inner: Session<T>,
}

impl<T> CliSession<T>
where
    T: DataStream,
{
    /// Wraps the given session without modifying it
    pub fn new(inner: Session<T>) -> Self {
        Self { inner }
    }
}
|
||||
|
||||
// TODO TODO TODO:
|
||||
//
|
||||
// 1. Change watch to broadcast if going to use in both loops, otherwise just make
|
||||
// it an mpsc otherwise
|
||||
// 2. Need to provide outgoing requests function with logic from inner.rs to create a request
|
||||
// based on the format (json or shell), where json uses serde_json::from_str and shell
|
||||
// uses Request::new(tenant.as_str(), vec![RequestData::from_iter_safe(...)])
|
||||
// 3. Need to add a wait method to block on the running tasks
|
||||
// 4. Need to add an abort method to abort the tasks
|
||||
// 5. Is there any way to deal with the blocking thread for stdin to kill it? This isn't a big
|
||||
// deal as the shutdown would only be happening on client termination anyway, but still...
|
||||
|
||||
/// Helper function that loops, processing incoming responses not tied to a request to be sent out
/// over stdout/stderr
///
/// Returns Ok(()) when the response stream ends or when the exit watch channel
/// changes (or its sender is dropped); returns a BrokenPipe error if the
/// broadcast stream itself yields an error
async fn process_incoming_responses(
    mut stream: BroadcastStream<Response>,
    format: Format,
    mut exit: watch::Receiver<bool>,
) -> io::Result<()> {
    loop {
        tokio::select! {
            // Next response from the server: format and print it immediately
            res = stream.next() => {
                match res {
                    Some(Ok(res)) => ResponseOut::new(format, res)?.print(),
                    Some(Err(x)) => return Err(io::Error::new(io::ErrorKind::BrokenPipe, x)),
                    // Stream is closed, so there is nothing left to process
                    None => return Ok(()),
                }
            }
            // Exit signal observed (any change, or sender dropped) -> stop
            _ = exit.changed() => {
                return Ok(());
            }
        }
    }
}
|
||||
|
||||
/// Helper function that loops, processing outgoing requests created from stdin, and printing out
|
||||
/// responses
|
||||
async fn process_outgoing_requests<T, F>(
|
||||
mut session: Session<T>,
|
||||
mut stdin_rx: mpsc::Receiver<String>,
|
||||
format: Format,
|
||||
map_line: F,
|
||||
) where
|
||||
T: DataStream,
|
||||
F: Fn(&str) -> io::Result<Request>,
|
||||
{
|
||||
let mut buf = StringBuf::new();
|
||||
|
||||
while let Some(data) = stdin_rx.recv().await {
|
||||
// Update our buffer with the new data and split it into concrete lines and remainder
|
||||
buf.push_str(&data);
|
||||
let (lines, new_buf) = buf.into_full_lines();
|
||||
buf = new_buf;
|
||||
|
||||
// For each complete line, parse into a request
|
||||
if let Some(lines) = lines {
|
||||
for line in lines.lines() {
|
||||
trace!("Processing line: {:?}", line);
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match map_line(line) {
|
||||
Ok(req) => match session.send(req).await {
|
||||
Ok(res) => match ResponseOut::new(format, res) {
|
||||
Ok(out) => out.print(),
|
||||
Err(x) => error!("Failed to format response: {}", x),
|
||||
},
|
||||
Err(x) => {
|
||||
error!("Failed to send request: {}", x)
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("Failed to parse line: {}", x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new thread that performs stdin reads in a blocking fashion, returning
|
||||
/// a handle to the thread and a receiver that will be sent input as it becomes available
|
||||
fn spawn_stdin_reader() -> (thread::JoinHandle<()>, mpsc::Receiver<String>) {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
|
||||
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
|
||||
// pass the results to a separate async handler to forward to the remote process
|
||||
let handle = thread::spawn(move || {
|
||||
let mut stdin = BufReader::new(io::stdin());
|
||||
|
||||
// Maximum chunk that we expect to read at any one time
|
||||
let mut buf = [0; MAX_PIPE_CHUNK_SIZE];
|
||||
|
||||
loop {
|
||||
match stdin.read(&mut buf) {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => {
|
||||
match String::from_utf8(buf[..n].to_vec()) {
|
||||
Ok(text) => {
|
||||
if let Err(x) = tx.blocking_send(text) {
|
||||
error!(
|
||||
"Failed to pass along stdin to be sent to remote process: {}",
|
||||
x
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Input over stdin is invalid: {}", x);
|
||||
}
|
||||
}
|
||||
thread::yield_now();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(handle, rx)
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
use crate::{
|
||||
cli::{
|
||||
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
|
||||
ExitCode, ExitCodeError,
|
||||
},
|
||||
core::server::DistantServer,
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use fork::{daemon, Fork};
|
||||
use log::*;
|
||||
use tokio::{io, task::JoinError};
|
||||
|
||||
/// Errors that can occur while running the listen subcommand
#[derive(Debug, Display, Error, From)]
pub enum Error {
    /// Failure to resolve the configured host into an IP address
    ConvertToIpAddrError(ConvertToIpAddrError),
    /// Failure to daemonize via fork or to close the forked process' fds
    ForkError,
    /// Underlying I/O failure (runtime creation, directory change, binding, ...)
    IoError(io::Error),
    /// Failure when joining on an async task
    JoinError(JoinError),
}
|
||||
|
||||
impl ExitCodeError for Error {
    /// Maps each listen error variant to its sysexits-style code
    fn to_exit_code(&self) -> ExitCode {
        match self {
            Self::ConvertToIpAddrError(_) => ExitCode::NoHost,
            Self::ForkError => ExitCode::OsErr,
            // Defer to the finer-grained mapping on the underlying I/O error
            Self::IoError(x) => x.to_exit_code(),
            Self::JoinError(_) => ExitCode::Software,
        }
    }
}
|
||||
|
||||
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
if cmd.daemon {
|
||||
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
|
||||
match daemon(false, true) {
|
||||
Ok(Fork::Child) => {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, true).await })?;
|
||||
}
|
||||
Ok(Fork::Parent(pid)) => {
|
||||
info!("[distant detached, pid = {}]", pid);
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(Error::ForkError),
|
||||
}
|
||||
} else {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, false).await })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
|
||||
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
|
||||
let socket_addrs = cmd.port.make_socket_addrs(addr);
|
||||
let shutdown_after = cmd.to_shutdown_after_duration();
|
||||
|
||||
// If specified, change the current working directory of this program
|
||||
if let Some(path) = cmd.current_dir.as_ref() {
|
||||
debug!("Setting current directory to {:?}", path);
|
||||
std::env::set_current_dir(path)?;
|
||||
}
|
||||
|
||||
// Bind & start our server
|
||||
let server = DistantServer::bind(
|
||||
addr,
|
||||
cmd.port,
|
||||
cmd.to_shutdown_after_duration(),
|
||||
cmd.max_msg_capacity as usize,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Print information about port, key, etc.
|
||||
println!(
|
||||
"DISTANT DATA -- {} {}",
|
||||
server.port(),
|
||||
server.to_unprotected_hex_auth_key()
|
||||
);
|
||||
|
||||
// For the child, we want to fully disconnect it from pipes, which we do now
|
||||
if is_forked {
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
|
||||
// Let our server run to completion
|
||||
server.wait().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,207 +0,0 @@
|
||||
use crate::{
|
||||
cli::opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
|
||||
core::{
|
||||
data::{Request, Response},
|
||||
net::{Transport, TransportReadHalf, TransportWriteHalf},
|
||||
session::Session,
|
||||
state::ServerState,
|
||||
utils,
|
||||
},
|
||||
ExitCode, ExitCodeError,
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use fork::{daemon, Fork};
|
||||
use log::*;
|
||||
use orion::aead::SecretKey;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
io,
|
||||
net::{tcp, TcpListener},
|
||||
runtime::Handle,
|
||||
sync::{mpsc, Mutex},
|
||||
};
|
||||
|
||||
mod handler;
|
||||
|
||||
/// Errors that can occur while running the listen subcommand
#[derive(Debug, Display, Error, From)]
pub enum Error {
    /// Failure to resolve the configured host into an IP address
    ConvertToIpAddrError(ConvertToIpAddrError),
    /// Failure to daemonize via fork or to close the forked process' fds
    ForkError,
    /// Underlying I/O failure (runtime creation, binding, accepting, ...)
    IoError(io::Error),
}
|
||||
|
||||
impl ExitCodeError for Error {
    /// Maps each listen error variant to its sysexits-style code
    fn to_exit_code(&self) -> ExitCode {
        match self {
            Self::ConvertToIpAddrError(_) => ExitCode::NoHost,
            Self::ForkError => ExitCode::OsErr,
            // Defer to the finer-grained mapping on the underlying I/O error
            Self::IoError(x) => x.to_exit_code(),
        }
    }
}
|
||||
|
||||
pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
if cmd.daemon {
|
||||
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
|
||||
match daemon(false, true) {
|
||||
Ok(Fork::Child) => {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, true).await })?;
|
||||
}
|
||||
Ok(Fork::Parent(pid)) => {
|
||||
info!("[distant detached, pid = {}]", pid);
|
||||
if let Err(_) = fork::close_fd() {
|
||||
return Err(Error::ForkError);
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(Error::ForkError),
|
||||
}
|
||||
} else {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async { run_async(cmd, opt, false).await })?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Performs the async portion of the listen subcommand: binds a TCP listener,
/// prints session info, then accepts clients in a loop until the listener
/// fails or the shutdown timeout fires
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
    let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
    let socket_addrs = cmd.port.make_socket_addrs(addr);
    let shutdown_after = cmd.to_shutdown_after_duration();

    // If specified, change the current working directory of this program
    if let Some(path) = cmd.current_dir.as_ref() {
        debug!("Setting current directory to {:?}", path);
        std::env::set_current_dir(path)?;
    }

    debug!("Binding to {} in range {}", addr, cmd.port);
    let listener = TcpListener::bind(socket_addrs.as_slice()).await?;

    let port = listener.local_addr()?.port();
    debug!("Bound to port: {}", port);

    // Print information about port, key, etc.
    let key = {
        // Host is a placeholder; this Session exists only to format the
        // connection info line and to generate the auth key
        let session = Session {
            host: "--".to_string(),
            port,
            auth_key: SecretKey::default(),
        };

        println!("{}", session.to_unprotected_string());

        Arc::new(session.into_auth_key())
    };

    // For the child, we want to fully disconnect it from pipes, which we do now
    if is_forked {
        if let Err(_) = fork::close_fd() {
            return Err(Error::ForkError);
        }
    }

    // Build our state for the server
    let state: Arc<Mutex<ServerState<SocketAddr>>> = Arc::new(Mutex::new(ServerState::default()));
    let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);

    // Wait for a client connection, then spawn a new task to handle
    // receiving data from the client
    loop {
        tokio::select! {
            result = listener.accept() => {match result {
                Ok((client, addr)) => {
                    // Establish a proper connection via a handshake, discarding the connection otherwise
                    let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await {
                        Ok(transport) => transport,
                        Err(x) => {
                            error!("<Client @ {}> Failed handshake: {}", addr, x);
                            continue;
                        }
                    };

                    // Split the transport into read and write halves so we can handle input
                    // and output concurrently
                    let (t_read, t_write) = transport.into_split();
                    let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
                    let ct_2 = Arc::clone(&ct);

                    // Spawn a new task that loops to handle requests from the client
                    tokio::spawn({
                        let f = request_loop(addr, Arc::clone(&state), t_read, tx);

                        let state = Arc::clone(&state);
                        // The connection counter (ct) brackets the request loop so the
                        // shutdown task knows whether any clients remain; client state
                        // is cleaned up once the loop ends
                        async move {
                            ct_2.lock().await.increment();
                            f.await;
                            state.lock().await.cleanup_client(addr).await;
                            ct_2.lock().await.decrement();
                        }
                    });

                    // Spawn a new task that loops to handle responses to the client
                    tokio::spawn(async move { response_loop(addr, t_write, rx).await });
                }
                Err(x) => {
                    error!("Listener failed: {}", x);
                    break;
                }
            }}
            _ = notify.notified() => {
                warn!("Reached shutdown timeout, so terminating");
                break;
            }
        }
    }

    Ok(())
}
|
||||
|
||||
/// Repeatedly reads in new requests, processes them, and sends their responses to the
/// response loop
///
/// Exits when the client closes the connection, when receiving fails, or when
/// the handler reports an error while processing a request
async fn request_loop(
    addr: SocketAddr,
    state: Arc<Mutex<ServerState<SocketAddr>>>,
    mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
    tx: mpsc::Sender<Response>,
) {
    loop {
        match transport.receive::<Request>().await {
            Ok(Some(req)) => {
                debug!(
                    "<Client @ {}> Received request of type{} {}",
                    addr,
                    // Pluralize "type" when the request carries multiple payload entries
                    if req.payload.len() > 1 { "s" } else { "" },
                    req.to_payload_type_string()
                );

                // Delegate processing to the handler; responses flow through tx
                if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await {
                    error!("<Client @ {}> {}", addr, x);
                    break;
                }
            }
            // None indicates the client closed the connection cleanly
            Ok(None) => {
                info!("<Client @ {}> Closed connection", addr);
                break;
            }
            Err(x) => {
                error!("<Client @ {}> {}", addr, x);
                break;
            }
        }
    }
}
|
||||
|
||||
/// Repeatedly sends responses out over the wire
|
||||
async fn response_loop(
|
||||
addr: SocketAddr,
|
||||
mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>,
|
||||
mut rx: mpsc::Receiver<Response>,
|
||||
) {
|
||||
while let Some(res) = rx.recv().await {
|
||||
if let Err(x) = transport.send(res).await {
|
||||
error!("<Client @ {}> {}", addr, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,197 @@
|
||||
use super::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
|
||||
use crate::core::{client::Session, net::DataStream};
|
||||
use std::{
|
||||
fmt::Write,
|
||||
io::{self, Cursor, Read},
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
mod data;
|
||||
pub use data::*;
|
||||
|
||||
/// Represents an LSP server process on a remote machine
pub struct RemoteLspProcess {
    // Underlying remote process; exposed through Deref/DerefMut
    inner: RemoteProcess,
    // LSP-aware pipe handles; each is Some until taken by a consumer
    pub stdin: Option<RemoteLspStdin>,
    pub stdout: Option<RemoteLspStdout>,
    pub stderr: Option<RemoteLspStderr>,
}
|
||||
|
||||
impl RemoteLspProcess {
    /// Spawns the specified process on the remote machine using the given session, treating
    /// the process like an LSP server
    pub async fn spawn<T>(
        session: Session<T>,
        cmd: String,
        args: Vec<String>,
    ) -> Result<Self, RemoteProcessError>
    where
        T: DataStream + 'static,
    {
        let mut inner = RemoteProcess::spawn(session, cmd, args).await?;

        // Take ownership of the raw pipes, wrapping each in an LSP-aware
        // handle that buffers partial messages and rewrites URI schemes
        let stdin = inner.stdin.take().map(RemoteLspStdin::new);
        let stdout = inner.stdout.take().map(RemoteLspStdout::new);
        let stderr = inner.stderr.take().map(RemoteLspStderr::new);

        Ok(RemoteLspProcess {
            inner,
            stdin,
            stdout,
            stderr,
        })
    }
}
|
||||
|
||||
// Deref/DerefMut passthrough so RemoteProcess methods can be called directly
// on a RemoteLspProcess
impl Deref for RemoteLspProcess {
    type Target = RemoteProcess;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for RemoteLspProcess {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
|
||||
|
||||
/// A handle to a remote LSP process' standard input (stdin)
pub struct RemoteLspStdin {
    // Raw stdin handle that receives fully-formed LSP messages
    inner: RemoteStdin,
    // Accumulates written text until a complete LSP message can be parsed
    buf: Option<String>,
}

impl RemoteLspStdin {
    pub fn new(inner: RemoteStdin) -> Self {
        Self { inner, buf: None }
    }

    /// Writes data to the stdin of a specific remote process
    ///
    /// Input is buffered until complete LSP messages can be parsed; each
    /// complete message has its distant:// URIs rewritten to file:// before
    /// being forwarded, and any trailing partial message is retained for a
    /// later call
    pub async fn write(&mut self, data: &str) -> io::Result<()> {
        let mut queue = Vec::new();

        // Create or insert into our buffer
        match &mut self.buf {
            Some(buf) => buf.push_str(data),
            None => self.buf = Some(data.to_string()),
        }

        // Read LSP messages from our internal buffer
        let buf = self.buf.take().unwrap();
        let mut cursor = Cursor::new(buf);
        while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
            queue.push(data);
        }

        // Keep remainder of string not processed as LSP message in buffer
        if (cursor.position() as usize) < cursor.get_ref().len() {
            let mut buf = String::new();
            cursor.read_to_string(&mut buf)?;
            self.buf = Some(buf);
        }

        // Process and then send out each LSP message in our queue
        for mut data in queue {
            // Convert distant:// to file://
            data.mut_content().convert_distant_scheme_to_local();
            self.inner.write(&data.to_string()).await?;
        }

        Ok(())
    }
}
|
||||
|
||||
/// A handle to a remote LSP process' standard output (stdout)
pub struct RemoteLspStdout {
    // Raw stdout handle whose chunks are reassembled into LSP messages
    inner: RemoteStdout,
    // Holds bytes read so far that do not yet form a complete LSP message
    buf: Option<String>,
}

impl RemoteLspStdout {
    pub fn new(inner: RemoteStdout) -> Self {
        Self { inner, buf: None }
    }

    /// Reads from the remote process' stdout, returning the concatenation of
    /// all complete LSP messages currently available (with file:// URIs
    /// rewritten to distant://); partial message data is buffered internally
    /// until a later read completes it
    pub async fn read(&mut self) -> io::Result<String> {
        let mut queue = Vec::new();
        let data = self.inner.read().await?;

        // Create or insert into our buffer
        match &mut self.buf {
            Some(buf) => buf.push_str(&data),
            None => self.buf = Some(data),
        }

        // Read LSP messages from our internal buffer
        let buf = self.buf.take().unwrap();
        let mut cursor = Cursor::new(buf);
        while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
            queue.push(data);
        }

        // Keep remainder of string not processed as LSP message in buffer
        if (cursor.position() as usize) < cursor.get_ref().len() {
            let mut buf = String::new();
            cursor.read_to_string(&mut buf)?;
            self.buf = Some(buf);
        }

        // Process and then add each LSP message as output
        let mut out = String::new();
        for mut data in queue {
            // Convert file:// to distant://
            data.mut_content().convert_local_scheme_to_distant();
            write!(&mut out, "{}", data).unwrap();
        }

        Ok(out)
    }
}
|
||||
|
||||
/// A handle to a remote LSP process' stderr
pub struct RemoteLspStderr {
    // Raw stderr handle whose chunks are reassembled into LSP messages
    inner: RemoteStderr,
    // Holds bytes read so far that do not yet form a complete LSP message
    buf: Option<String>,
}

impl RemoteLspStderr {
    pub fn new(inner: RemoteStderr) -> Self {
        Self { inner, buf: None }
    }

    /// Reads from the remote process' stderr, returning the concatenation of
    /// all complete LSP messages currently available (with file:// URIs
    /// rewritten to distant://); partial message data is buffered internally
    /// until a later read completes it
    pub async fn read(&mut self) -> io::Result<String> {
        let mut queue = Vec::new();
        let data = self.inner.read().await?;

        // Create or insert into our buffer
        match &mut self.buf {
            Some(buf) => buf.push_str(&data),
            None => self.buf = Some(data),
        }

        // Read LSP messages from our internal buffer
        let buf = self.buf.take().unwrap();
        let mut cursor = Cursor::new(buf);
        while let Ok(data) = LspData::from_buf_reader(&mut cursor) {
            queue.push(data);
        }

        // Keep remainder of string not processed as LSP message in buffer
        if (cursor.position() as usize) < cursor.get_ref().len() {
            let mut buf = String::new();
            cursor.read_to_string(&mut buf)?;
            self.buf = Some(buf);
        }

        // Process and then add each LSP message as output
        let mut out = String::new();
        for mut data in queue {
            // Convert file:// to distant://
            data.mut_content().convert_local_scheme_to_distant();
            write!(&mut out, "{}", data).unwrap();
        }

        Ok(out)
    }
}
|
@ -0,0 +1,16 @@
|
||||
mod lsp;
|
||||
mod process;
|
||||
mod session;
|
||||
mod utils;
|
||||
|
||||
// TODO: Make wrappers around a connection to facilitate the types
|
||||
// of engagements
|
||||
//
|
||||
// 1. Command -> Single request/response through a future
|
||||
// 2. Proxy -> Does proc-run and waits until proc-done received,
|
||||
// exposing a sender for stdin and receivers for stdout/stderr,
|
||||
// and supporting a future await for completion with exit code
|
||||
// 3.
|
||||
pub use lsp::*;
|
||||
pub use process::{RemoteProcess, RemoteProcessError, RemoteStderr, RemoteStdin, RemoteStdout};
|
||||
pub use session::*;
|
@ -0,0 +1,286 @@
|
||||
use crate::core::{
|
||||
client::{utils, Session},
|
||||
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
|
||||
data::{Request, RequestData, Response, ResponseData},
|
||||
net::{DataStream, TransportError},
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use tokio::{
|
||||
io,
|
||||
sync::mpsc,
|
||||
task::{JoinError, JoinHandle},
|
||||
};
|
||||
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
|
||||
|
||||
/// Errors that can occur while spawning or interacting with a remote process
#[derive(Debug, Display, Error, From)]
pub enum RemoteProcessError {
    /// When the process receives an unexpected response
    BadResponse,

    /// When attempting to relay stdout/stderr over channels, but the channels fail
    ChannelDead,

    /// When process is unable to read stdout/stderr from the server
    /// fast enough, resulting in dropped data
    Overloaded,

    /// When the communication over the wire has issues
    TransportError(TransportError),

    /// When the stream of responses from the server closes without receiving
    /// an indicator of the process' exit status
    UnexpectedEof,

    /// When attempting to wait on the remote process, but the internal task joining failed
    WaitFailed(JoinError),
}
|
||||
|
||||
/// Represents a process on a remote machine
///
/// Created via [`RemoteProcess::spawn`]. The `stdin`, `stdout`, and `stderr` handles are
/// wrapped in `Option`, presumably so callers can `take()` each one for independent use —
/// TODO confirm intended usage against callers
pub struct RemoteProcess {
    /// Id of the process
    id: usize,

    /// Task that forwards stdin to the remote process by bundling it as stdin requests
    req_task: JoinHandle<Result<(), RemoteProcessError>>,

    /// Task that reads in new responses, which returns the success and optional
    /// exit code once the process has completed
    res_task: JoinHandle<Result<(bool, Option<i32>), RemoteProcessError>>,

    /// Sender for stdin
    pub stdin: Option<RemoteStdin>,

    /// Receiver for stdout
    pub stdout: Option<RemoteStdout>,

    /// Receiver for stderr
    pub stderr: Option<RemoteStderr>,

    /// Sender for kill events
    kill: mpsc::Sender<()>,
}
|
||||
|
||||
impl RemoteProcess {
    /// Spawns the specified process on the remote machine using the given session
    ///
    /// Consumes the session, dedicating it to this process: one spawned task forwards stdin
    /// (and kill signals) as requests, while another consumes stdout/stderr/exit responses.
    ///
    /// Errors with `BadResponse` if the server does not reply with a single `ProcStart`,
    /// or with `TransportError` if the initial request fails.
    pub async fn spawn<T>(
        mut session: Session<T>,
        cmd: String,
        args: Vec<String>,
    ) -> Result<Self, RemoteProcessError>
    where
        T: DataStream + 'static,
    {
        let tenant = utils::new_tenant();

        // Submit our run request and wait for a response
        let res = session
            .send(Request::new(
                tenant.as_str(),
                vec![RequestData::ProcRun { cmd, args }],
            ))
            .await?;

        // We expect a singular response back
        if res.payload.len() != 1 {
            return Err(RemoteProcessError::BadResponse);
        }

        // Response should be proc starting
        let id = match res.payload.into_iter().next().unwrap() {
            ResponseData::ProcStart { id } => id,
            _ => return Err(RemoteProcessError::BadResponse),
        };

        // Create channels for our stdin/stdout/stderr
        let (stdin_tx, stdin_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
        let (stdout_tx, stdout_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
        let (stderr_tx, stderr_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);

        // Now we spawn a task to handle future responses that are async
        // such as ProcStdout, ProcStderr, and ProcDone
        let stream = session.to_response_broadcast_stream();
        let res_task = tokio::spawn(async move {
            process_incoming_responses(id, stream, stdout_tx, stderr_tx).await
        });

        // Spawn a task that takes stdin from our channel and forwards it to the remote process
        // (kill channel capacity of 1 is enough: a single kill signal terminates the task)
        let (kill_tx, kill_rx) = mpsc::channel(1);
        let req_task = tokio::spawn(async move {
            process_outgoing_requests(tenant, id, session, stdin_rx, kill_rx).await
        });

        Ok(Self {
            id,
            req_task,
            res_task,
            stdin: Some(RemoteStdin(stdin_tx)),
            stdout: Some(RemoteStdout(stdout_rx)),
            stderr: Some(RemoteStderr(stderr_rx)),
            kill: kill_tx,
        })
    }

    /// Returns the id of the running process
    pub fn id(&self) -> usize {
        self.id
    }

    /// Waits for the process to terminate, returning the success status and an optional exit code
    ///
    /// A failure to join the internal response task surfaces as `WaitFailed` (via `From<JoinError>`)
    pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
        // Outer `?` converts a JoinError into RemoteProcessError; inner Result is returned as-is
        self.res_task.await?
    }

    /// Aborts the process by forcing its response task to shutdown, which means that a call
    /// to `wait` will return an error. Note that this does **not** send a kill request, so if
    /// you want to be nice you should send the request before aborting.
    pub fn abort(&self) {
        self.req_task.abort();
        self.res_task.abort();
    }

    /// Submits a kill request for the running process
    ///
    /// Errors with `ChannelDead` if the internal request-forwarding task has already exited
    pub async fn kill(&mut self) -> Result<(), RemoteProcessError> {
        self.kill
            .send(())
            .await
            .map_err(|_| RemoteProcessError::ChannelDead)?;
        Ok(())
    }
}
|
||||
|
||||
/// A handle to a remote process' standard input (stdin)
|
||||
pub struct RemoteStdin(mpsc::Sender<String>);
|
||||
|
||||
impl RemoteStdin {
|
||||
/// Writes data to the stdin of a specific remote process
|
||||
pub async fn write(&mut self, data: impl Into<String>) -> io::Result<()> {
|
||||
self.0
|
||||
.send(data.into())
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle to a remote process' standard output (stdout)
|
||||
pub struct RemoteStdout(mpsc::Receiver<String>);
|
||||
|
||||
impl RemoteStdout {
|
||||
/// Retrieves the latest stdout for a specific remote process
|
||||
pub async fn read(&mut self) -> io::Result<String> {
|
||||
self.0
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle to a remote process' stderr
|
||||
pub struct RemoteStderr(mpsc::Receiver<String>);
|
||||
|
||||
impl RemoteStderr {
|
||||
/// Retrieves the latest stderr for a specific remote process
|
||||
pub async fn read(&mut self) -> io::Result<String> {
|
||||
self.0
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function that loops, processing outgoing stdin requests to a remote process as well as
/// supporting a kill request to terminate the remote process
///
/// Exits with `Ok(())` after a kill request has been fired, or with `ChannelDead` if either
/// the stdin channel or the kill channel closes first
async fn process_outgoing_requests<T>(
    tenant: String,
    id: usize,
    mut session: Session<T>,
    mut stdin_rx: mpsc::Receiver<String>,
    mut kill_rx: mpsc::Receiver<()>,
) -> Result<(), RemoteProcessError>
where
    T: DataStream,
{
    loop {
        // Race the two channels; whichever yields first is handled this iteration
        tokio::select! {
            data = stdin_rx.recv() => {
                match data {
                    // Forward stdin as a fire-and-forget request (no response awaited)
                    Some(data) => session.fire(
                        Request::new(
                            tenant.as_str(),
                            vec![RequestData::ProcStdin { id, data }]
                        )
                    ).await?,
                    // Stdin channel closed, so no further input can ever arrive
                    None => break Err(RemoteProcessError::ChannelDead),
                }
            }
            msg = kill_rx.recv() => {
                if msg.is_some() {
                    // Fire the kill request, then end this task successfully
                    session
                        .fire(Request::new(
                            tenant.as_str(),
                            vec![RequestData::ProcKill { id }],
                        ))
                        .await?;
                    break Ok(());
                } else {
                    break Err(RemoteProcessError::ChannelDead);
                }
            }
        }
    }
}
|
||||
|
||||
/// Helper function that loops, processing incoming stdout & stderr requests from a remote process
|
||||
async fn process_incoming_responses(
|
||||
proc_id: usize,
|
||||
mut stream: BroadcastStream<Response>,
|
||||
stdout_tx: mpsc::Sender<String>,
|
||||
stderr_tx: mpsc::Sender<String>,
|
||||
) -> Result<(bool, Option<i32>), RemoteProcessError> {
|
||||
let mut result = Err(RemoteProcessError::UnexpectedEof);
|
||||
|
||||
while let Some(res) = stream.next().await {
|
||||
match res {
|
||||
Ok(res) => {
|
||||
// Check if any of the payload data is the termination
|
||||
let exit_status = res.payload.iter().find_map(|data| match data {
|
||||
ResponseData::ProcDone { id, success, code } if *id == proc_id => {
|
||||
Some((*success, *code))
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
|
||||
// Next, check for stdout/stderr and send them along our channels
|
||||
// TODO: What should we do about unexpected data? For now, just ignore
|
||||
for data in res.payload {
|
||||
match data {
|
||||
ResponseData::ProcStdout { id, data } if id == proc_id => {
|
||||
if let Err(_) = stdout_tx.send(data).await {
|
||||
result = Err(RemoteProcessError::ChannelDead);
|
||||
break;
|
||||
}
|
||||
}
|
||||
ResponseData::ProcStderr { id, data } if id == proc_id => {
|
||||
if let Err(_) = stderr_tx.send(data).await {
|
||||
result = Err(RemoteProcessError::ChannelDead);
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// If we got a termination, then exit accordingly
|
||||
if let Some((success, code)) = exit_status {
|
||||
result = Ok((success, code));
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
result = Err(RemoteProcessError::Overloaded);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
use std::{future::Future, time::Duration};
|
||||
use tokio::{io, time};
|
||||
|
||||
// Generates a new tenant name
|
||||
pub fn new_tenant() -> String {
|
||||
format!("tenant_{}{}", rand::random::<u16>(), rand::random::<u8>())
|
||||
}
|
||||
|
||||
// Wraps a future in a tokio timeout call, transforming the error into
|
||||
// an io error
|
||||
pub async fn timeout<T, F>(d: Duration, f: F) -> io::Result<T>
|
||||
where
|
||||
F: Future<Output = T>,
|
||||
{
|
||||
time::timeout(d, f)
|
||||
.await
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
|
||||
}
|
@ -1,7 +1,5 @@
|
||||
pub mod client;
|
||||
pub mod constants;
|
||||
pub mod data;
|
||||
pub mod lsp;
|
||||
pub mod net;
|
||||
pub mod session;
|
||||
pub mod state;
|
||||
pub mod utils;
|
||||
pub mod server;
|
||||
|
@ -1,8 +1,5 @@
|
||||
mod transport;
|
||||
pub use transport::*;
|
||||
|
||||
mod client;
|
||||
pub use client::Client;
|
||||
|
||||
// Re-export commonly-used orion structs
|
||||
pub use orion::aead::SecretKey;
|
||||
|
@ -0,0 +1,222 @@
|
||||
mod handler;
|
||||
mod port;
|
||||
mod state;
|
||||
mod utils;
|
||||
|
||||
pub use port::{PortRange, PortRangeParseError};
|
||||
use state::State;
|
||||
|
||||
use crate::core::{
|
||||
data::{Request, Response},
|
||||
net::{SecretKey, Transport, TransportReadHalf, TransportWriteHalf},
|
||||
};
|
||||
use log::*;
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::{
|
||||
io,
|
||||
net::{tcp, TcpListener, TcpStream},
|
||||
runtime::Handle,
|
||||
sync::{mpsc, Mutex, Notify},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// Represents a server that listens for requests, processes them, and sends responses
pub struct Server {
    /// Port the server is actually bound to (resolved after binding, so it reflects an
    /// OS-assigned port when one was requested)
    port: u16,
    /// Shared server state, keyed by the socket address of each client connection
    state: Arc<Mutex<State<SocketAddr>>>,
    /// Secret key used to perform the transport handshake with new connections
    auth_key: Arc<SecretKey>,
    /// Notifier used to signal the connection loop to shut down
    notify: Arc<Notify>,
    /// Handle to the task running the connection accept loop
    conn_task: JoinHandle<()>,
}
|
||||
|
||||
impl Server {
    /// Binds a TCP listener to the first available port of `port` on `addr` and spawns the
    /// connection accept loop
    ///
    /// * `shutdown_after` - when provided, presumably causes the server to shut itself down
    ///   after this duration without active connections (enforced by
    ///   `utils::new_shutdown_task` — TODO confirm)
    /// * `max_msg_capacity` - capacity of each connection's outbound response channel
    pub async fn bind(
        addr: IpAddr,
        port: PortRange,
        shutdown_after: Option<Duration>,
        max_msg_capacity: usize,
    ) -> io::Result<Self> {
        debug!("Binding to {} in range {}", addr, port);
        let listener = TcpListener::bind(port.make_socket_addrs(addr).as_slice()).await?;

        // Re-read the port since the OS may have chosen one from the range
        let port = listener.local_addr()?.port();
        debug!("Bound to port: {}", port);

        // Build our state for the server
        let state: Arc<Mutex<State<SocketAddr>>> = Arc::new(Mutex::new(State::default()));
        let auth_key = Arc::new(SecretKey::default());
        let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);

        // Spawn our connection task
        let state_2 = Arc::clone(&state);
        let auth_key_2 = Arc::clone(&auth_key);
        let notify_2 = Arc::clone(&notify);
        let conn_task = tokio::spawn(async move {
            connection_loop(
                listener,
                state_2,
                auth_key_2,
                ct,
                notify_2,
                max_msg_capacity,
            )
            .await
        });

        Ok(Self {
            port,
            state,
            auth_key,
            notify,
            conn_task,
        })
    }

    /// Returns the port this server is bound to
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Returns a string representing the auth key as hex
    ///
    /// NOTE: exposes key material in plain text; handle the returned string carefully
    pub fn to_unprotected_hex_auth_key(&self) -> String {
        hex::encode(self.auth_key.unprotected_as_bytes())
    }

    /// Waits for the server to terminate (i.e. for the connection loop task to finish)
    pub async fn wait(self) -> Result<(), JoinError> {
        self.conn_task.await
    }

    /// Shutdown the server by signaling the connection loop to exit
    pub fn shutdown(&self) {
        self.notify.notify_one()
    }
}
|
||||
|
||||
async fn connection_loop(
|
||||
listener: TcpListener,
|
||||
state: Arc<Mutex<State<SocketAddr>>>,
|
||||
auth_key: Arc<SecretKey>,
|
||||
tracker: Arc<Mutex<utils::ConnTracker>>,
|
||||
notify: Arc<Notify>,
|
||||
max_msg_capacity: usize,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {match result {
|
||||
Ok((conn, addr)) => {
|
||||
if let Err(x) = on_new_conn(
|
||||
conn,
|
||||
addr,
|
||||
Arc::clone(&state),
|
||||
Arc::clone(&auth_key),
|
||||
Arc::clone(&tracker),
|
||||
max_msg_capacity
|
||||
).await {
|
||||
error!("<Conn @ {}> Failed handshake: {}", addr, x);
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Listener failed: {}", x);
|
||||
break;
|
||||
}
|
||||
}}
|
||||
_ = notify.notified() => {
|
||||
warn!("Reached shutdown timeout, so terminating");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Processes a new connection, performing a handshake, and then spawning two tasks to handle
/// input and output, returning join handles for the input and output tasks respectively
async fn on_new_conn(
    conn: TcpStream,
    addr: SocketAddr,
    state: Arc<Mutex<State<SocketAddr>>>,
    auth_key: Arc<SecretKey>,
    tracker: Arc<Mutex<utils::ConnTracker>>,
    max_msg_capacity: usize,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> {
    // Establish a proper connection via a handshake,
    // discarding the connection otherwise
    let transport = Transport::from_handshake(conn, Some(auth_key)).await?;

    // Split the transport into read and write halves so we can handle input
    // and output concurrently
    let (t_read, t_write) = transport.into_split();
    let (tx, rx) = mpsc::channel(max_msg_capacity);
    let ct_2 = Arc::clone(&tracker);

    // Spawn a new task that loops to handle requests from the client
    let req_task = tokio::spawn({
        let f = request_loop(addr, Arc::clone(&state), t_read, tx);

        let state = Arc::clone(&state);
        async move {
            // Mark this connection as active for the lifetime of the request loop;
            // presumably this keeps the shutdown timer from firing — TODO confirm
            // against utils::ConnTracker
            ct_2.lock().await.increment();
            f.await;
            // Remove per-client state (e.g. processes/watches) once the client is gone
            state.lock().await.cleanup_client(addr).await;
            ct_2.lock().await.decrement();
        }
    });

    // Spawn a new task that loops to handle responses to the client
    let res_task = tokio::spawn(async move { response_loop(addr, t_write, rx).await });

    Ok((req_task, res_task))
}
|
||||
|
||||
/// Repeatedly reads in new requests, processes them, and sends their responses to the
|
||||
/// response loop
|
||||
async fn request_loop(
|
||||
addr: SocketAddr,
|
||||
state: Arc<Mutex<State<SocketAddr>>>,
|
||||
mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
|
||||
tx: mpsc::Sender<Response>,
|
||||
) {
|
||||
loop {
|
||||
match transport.receive::<Request>().await {
|
||||
Ok(Some(req)) => {
|
||||
debug!(
|
||||
"<Conn @ {}> Received request of type{} {}",
|
||||
addr,
|
||||
if req.payload.len() > 1 { "s" } else { "" },
|
||||
req.to_payload_type_string()
|
||||
);
|
||||
|
||||
if let Err(x) = handler::process(addr, Arc::clone(&state), req, tx.clone()).await {
|
||||
error!("<Conn @ {}> {}", addr, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
info!("<Conn @ {}> Closed connection", addr);
|
||||
break;
|
||||
}
|
||||
Err(x) => {
|
||||
error!("<Conn @ {}> {}", addr, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Repeatedly sends responses out over the wire
|
||||
async fn response_loop(
|
||||
addr: SocketAddr,
|
||||
mut transport: TransportWriteHalf<tcp::OwnedWriteHalf>,
|
||||
mut rx: mpsc::Receiver<Response>,
|
||||
) {
|
||||
while let Some(res) = rx.recv().await {
|
||||
if let Err(x) = transport.send(res).await {
|
||||
error!("<Conn @ {}> {}", addr, x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
use derive_more::{Display, Error};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Represents some range of ports
///
/// Displays as the start port alone, with `[:end]` appended when an end port is present
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
    fmt = "{}{}",
    start,
    "end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()"
)]
pub struct PortRange {
    /// First (or only) port in the range
    pub start: u16,
    /// Optional final port; when present the range covers `start..=end` inclusively
    pub end: Option<u16>,
}
|
||||
|
||||
impl PortRange {
|
||||
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
|
||||
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
|
||||
let mut socket_addrs = Vec::new();
|
||||
let addr = addr.into();
|
||||
|
||||
for port in self.start..=self.end.unwrap_or(self.start) {
|
||||
socket_addrs.push(SocketAddr::from((addr, port)));
|
||||
}
|
||||
|
||||
socket_addrs
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur when parsing a port range from a string
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
pub enum PortRangeParseError {
    /// A port token was present but not a valid value, or extra tokens were supplied
    InvalidPort,
    /// No port was found in the input
    MissingPort,
}
|
||||
|
||||
impl FromStr for PortRange {
|
||||
type Err = PortRangeParseError;
|
||||
|
||||
/// Parses PORT into single range or PORT1:PORTN into full range
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut tokens = s.trim().split(':');
|
||||
let start = tokens
|
||||
.next()
|
||||
.ok_or(PortRangeParseError::MissingPort)?
|
||||
.parse::<u16>()
|
||||
.map_err(|_| PortRangeParseError::InvalidPort)?;
|
||||
let end = if let Some(token) = tokens.next() {
|
||||
Some(
|
||||
token
|
||||
.parse::<u16>()
|
||||
.map_err(|_| PortRangeParseError::InvalidPort)?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if tokens.next().is_some() {
|
||||
return Err(PortRangeParseError::InvalidPort);
|
||||
}
|
||||
|
||||
Ok(Self { start, end })
|
||||
}
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
mod distant;
|
||||
pub use distant::{PortRange, PortRangeParseError, Server as DistantServer};
|
Loading…
Reference in New Issue