mirror of https://github.com/chipsenkbeil/distant
Refactor codebase to be more testable & add some initial tests
parent
1ca3cd7859
commit
ba6ebcfcb8
@ -0,0 +1,112 @@
|
||||
use crate::{
|
||||
cli,
|
||||
core::{
|
||||
client::{
|
||||
RemoteLspStderr, RemoteLspStdin, RemoteLspStdout, RemoteStderr, RemoteStdin,
|
||||
RemoteStdout,
|
||||
},
|
||||
constants::MAX_PIPE_CHUNK_SIZE,
|
||||
},
|
||||
};
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
thread,
|
||||
};
|
||||
use tokio::task::{JoinError, JoinHandle};
|
||||
|
||||
/// Represents a link between a remote process' stdin/stdout/stderr and this process'
|
||||
/// stdin/stdout/stderr
|
||||
pub struct RemoteProcessLink {
|
||||
_stdin_thread: thread::JoinHandle<()>,
|
||||
stdin_task: JoinHandle<io::Result<()>>,
|
||||
stdout_task: JoinHandle<io::Result<()>>,
|
||||
stderr_task: JoinHandle<io::Result<()>>,
|
||||
}
|
||||
|
||||
macro_rules! from_pipes {
|
||||
($stdin:expr, $stdout:expr, $stderr:expr) => {{
|
||||
let (stdin_thread, mut stdin_rx) = cli::stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE);
|
||||
let stdin_task = tokio::spawn(async move {
|
||||
loop {
|
||||
if let Some(input) = stdin_rx.recv().await {
|
||||
if let Err(x) = $stdin.write(input.as_str()).await {
|
||||
break Err(x);
|
||||
}
|
||||
} else {
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
});
|
||||
let stdout_task = tokio::spawn(async move {
|
||||
let handle = io::stdout();
|
||||
loop {
|
||||
match $stdout.read().await {
|
||||
Ok(output) => {
|
||||
let mut out = handle.lock();
|
||||
out.write_all(output.as_bytes())?;
|
||||
out.flush()?;
|
||||
}
|
||||
Err(x) => break Err(x),
|
||||
}
|
||||
}
|
||||
});
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let handle = io::stderr();
|
||||
loop {
|
||||
match $stderr.read().await {
|
||||
Ok(output) => {
|
||||
let mut out = handle.lock();
|
||||
out.write_all(output.as_bytes())?;
|
||||
out.flush()?;
|
||||
}
|
||||
Err(x) => break Err(x),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
RemoteProcessLink {
|
||||
_stdin_thread: stdin_thread,
|
||||
stdin_task,
|
||||
stdout_task,
|
||||
stderr_task,
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
impl RemoteProcessLink {
|
||||
/// Creates a new process link from the pipes of a remote process
|
||||
pub fn from_remote_pipes(
|
||||
mut stdin: RemoteStdin,
|
||||
mut stdout: RemoteStdout,
|
||||
mut stderr: RemoteStderr,
|
||||
) -> Self {
|
||||
from_pipes!(stdin, stdout, stderr)
|
||||
}
|
||||
|
||||
/// Creates a new process link from the pipes of a remote LSP server process
|
||||
pub fn from_remote_lsp_pipes(
|
||||
mut stdin: RemoteLspStdin,
|
||||
mut stdout: RemoteLspStdout,
|
||||
mut stderr: RemoteLspStderr,
|
||||
) -> Self {
|
||||
from_pipes!(stdin, stdout, stderr)
|
||||
}
|
||||
|
||||
/// Shuts down the link, aborting any running tasks, and swallowing join errors
|
||||
pub async fn shutdown(self) {
|
||||
self.abort();
|
||||
let _ = self.wait().await;
|
||||
}
|
||||
|
||||
/// Waits for the stdin, stdout, and stderr tasks to complete
|
||||
pub async fn wait(self) -> Result<(), JoinError> {
|
||||
tokio::try_join!(self.stdin_task, self.stdout_task, self.stderr_task).map(|_| ())
|
||||
}
|
||||
|
||||
/// Aborts the link by aborting tasks processing stdin, stdout, and stderr
|
||||
pub fn abort(&self) {
|
||||
self.stdin_task.abort();
|
||||
self.stdout_task.abort();
|
||||
self.stderr_task.abort();
|
||||
}
|
||||
}
|
@ -1,10 +1,13 @@
|
||||
mod buf;
|
||||
mod exit;
|
||||
mod link;
|
||||
mod opt;
|
||||
mod output;
|
||||
mod session;
|
||||
mod stdin;
|
||||
mod subcommand;
|
||||
|
||||
pub use exit::{ExitCode, ExitCodeError};
|
||||
pub use opt::*;
|
||||
pub use output::ResponseOut;
|
||||
pub use session::CliSession;
|
||||
|
@ -0,0 +1,43 @@
|
||||
use log::error;
|
||||
use std::{
|
||||
io::{self, BufReader, Read},
|
||||
thread,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Creates a new thread that performs stdin reads in a blocking fashion, returning
|
||||
/// a handle to the thread and a receiver that will be sent input as it becomes available
|
||||
pub fn spawn_channel(buffer: usize) -> (thread::JoinHandle<()>, mpsc::Receiver<String>) {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
|
||||
// NOTE: Using blocking I/O per tokio's advice to read from stdin line-by-line and then
|
||||
// pass the results to a separate async handler to forward to the remote process
|
||||
let handle = thread::spawn(move || {
|
||||
let mut stdin = BufReader::new(io::stdin());
|
||||
|
||||
// Maximum chunk that we expect to read at any one time
|
||||
let mut buf = vec![0; buffer];
|
||||
|
||||
loop {
|
||||
match stdin.read(&mut buf) {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => {
|
||||
match String::from_utf8(buf[..n].to_vec()) {
|
||||
Ok(text) => {
|
||||
if let Err(x) = tx.blocking_send(text) {
|
||||
error!("Stdin channel closed: {}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Input over stdin is invalid: {}", x);
|
||||
}
|
||||
}
|
||||
thread::yield_now();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(handle, rx)
|
||||
}
|
@ -0,0 +1,186 @@
|
||||
use crate::{
|
||||
cli::{
|
||||
link::RemoteProcessLink,
|
||||
opt::{ActionSubcommand, CommonOpt, SessionInput},
|
||||
CliSession, ExitCode, ExitCodeError, ResponseOut,
|
||||
},
|
||||
core::{
|
||||
client::{
|
||||
self, LspData, RemoteProcess, RemoteProcessError, Session, SessionInfo, SessionInfoFile,
|
||||
},
|
||||
data::{Request, RequestData},
|
||||
net::{DataStream, TransportError},
|
||||
},
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum Error {
|
||||
#[display(fmt = "Process failed with exit code: {}", _0)]
|
||||
BadProcessExit(#[error(not(source))] i32),
|
||||
IoError(io::Error),
|
||||
#[display(fmt = "Non-interactive but no operation supplied")]
|
||||
MissingOperation,
|
||||
RemoteProcessError(RemoteProcessError),
|
||||
TransportError(TransportError),
|
||||
}
|
||||
|
||||
impl ExitCodeError for Error {
|
||||
fn to_exit_code(&self) -> ExitCode {
|
||||
match self {
|
||||
Self::BadProcessExit(x) => ExitCode::Custom(*x),
|
||||
Self::IoError(x) => x.to_exit_code(),
|
||||
Self::MissingOperation => ExitCode::Usage,
|
||||
Self::RemoteProcessError(x) => x.to_exit_code(),
|
||||
Self::TransportError(x) => x.to_exit_code(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
|
||||
rt.block_on(async { run_async(cmd, opt).await })
|
||||
}
|
||||
|
||||
async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
let timeout = opt.to_timeout_duration();
|
||||
|
||||
match cmd.session {
|
||||
SessionInput::Environment => {
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::File => {
|
||||
let path = cmd.session_data.session_file.clone();
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(
|
||||
SessionInfoFile::load_from(path).await?.into(),
|
||||
timeout,
|
||||
)
|
||||
.await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::Pipe => {
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::Lsp => {
|
||||
let mut data =
|
||||
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
|
||||
let info = data.take_session_info().map_err(io::Error::from)?;
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(info, timeout).await?,
|
||||
timeout,
|
||||
Some(data),
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(unix)]
|
||||
SessionInput::Socket => {
|
||||
let path = cmd.session_data.session_socket.clone();
|
||||
start(
|
||||
cmd,
|
||||
Session::unix_connect_timeout(path, None, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn start<T>(
|
||||
cmd: ActionSubcommand,
|
||||
mut session: Session<T>,
|
||||
timeout: Duration,
|
||||
lsp_data: Option<LspData>,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: DataStream + 'static,
|
||||
{
|
||||
// TODO: Because lsp is being handled in a separate action, we should fail if we get
|
||||
// a session type of lsp for a regular action
|
||||
match (cmd.interactive, cmd.operation) {
|
||||
// ProcRun request is specially handled and we ignore interactive as
|
||||
// the stdin will be used for sending ProcStdin to remote process
|
||||
(_, Some(RequestData::ProcRun { cmd, args })) => {
|
||||
let mut proc = RemoteProcess::spawn(client::new_tenant(), session, cmd, args).await?;
|
||||
|
||||
// If we also parsed an LSP's initialize request for its session, we want to forward
|
||||
// it along in the case of a process call
|
||||
if let Some(data) = lsp_data {
|
||||
proc.stdin.as_mut().unwrap().write(data.to_string()).await?;
|
||||
}
|
||||
|
||||
// Now, map the remote process' stdin/stdout/stderr to our own process
|
||||
let link = RemoteProcessLink::from_remote_pipes(
|
||||
proc.stdin.take().unwrap(),
|
||||
proc.stdout.take().unwrap(),
|
||||
proc.stderr.take().unwrap(),
|
||||
);
|
||||
|
||||
let (success, exit_code) = proc.wait().await?;
|
||||
|
||||
// Shut down our link
|
||||
link.shutdown().await;
|
||||
|
||||
if !success {
|
||||
if let Some(code) = exit_code {
|
||||
return Err(Error::BadProcessExit(code));
|
||||
} else {
|
||||
return Err(Error::BadProcessExit(1));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// All other requests without interactive are oneoffs
|
||||
(false, Some(data)) => {
|
||||
let res = session
|
||||
.send_timeout(Request::new(client::new_tenant(), vec![data]), timeout)
|
||||
.await?;
|
||||
ResponseOut::new(cmd.format, res)?.print();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Interactive mode will send an optional first request and then continue
|
||||
// to read stdin to send more
|
||||
(true, maybe_req) => {
|
||||
// Send our first request if provided
|
||||
if let Some(data) = maybe_req {
|
||||
let res = session
|
||||
.send_timeout(Request::new(client::new_tenant(), vec![data]), timeout)
|
||||
.await?;
|
||||
ResponseOut::new(cmd.format, res)?.print();
|
||||
}
|
||||
|
||||
// Enter into CLI session where we receive requests on stdin and send out
|
||||
// over stdout/stderr
|
||||
let cli_session = CliSession::new(client::new_tenant(), session, cmd.format);
|
||||
cli_session.wait().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Not interactive and no operation given
|
||||
(false, None) => Err(Error::MissingOperation),
|
||||
}
|
||||
}
|
@ -1,182 +0,0 @@
|
||||
use crate::{
|
||||
cli::opt::Format,
|
||||
core::{
|
||||
constants::MAX_PIPE_CHUNK_SIZE,
|
||||
data::{Error, Request, RequestData, Response, ResponseData},
|
||||
net::{Client, DataStream},
|
||||
utils::StringBuf,
|
||||
},
|
||||
};
|
||||
use derive_more::IsVariant;
|
||||
use log::*;
|
||||
use std::marker::Unpin;
|
||||
use structopt::StructOpt;
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
sync::{
|
||||
mpsc,
|
||||
oneshot::{self, error::TryRecvError},
|
||||
},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, IsVariant)]
|
||||
pub enum LoopConfig {
|
||||
Json,
|
||||
Proc { id: usize },
|
||||
Shell,
|
||||
}
|
||||
|
||||
impl From<LoopConfig> for Format {
|
||||
fn from(config: LoopConfig) -> Self {
|
||||
match config {
|
||||
LoopConfig::Json => Self::Json,
|
||||
LoopConfig::Proc { .. } | LoopConfig::Shell => Self::Shell,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts a new action loop that processes requests and receives responses
|
||||
///
|
||||
/// id represents the id of a remote process
|
||||
pub async fn interactive_loop<T>(
|
||||
mut client: Client<T>,
|
||||
tenant: String,
|
||||
config: LoopConfig,
|
||||
) -> io::Result<()>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + DataStream + Unpin + 'static,
|
||||
{
|
||||
let mut stream = client.to_response_broadcast_stream();
|
||||
|
||||
// Create a channel that can report when we should stop the loop based on a received request
|
||||
let (tx_stop, mut rx_stop) = oneshot::channel::<()>();
|
||||
|
||||
// We also want to spawn a task to handle sending stdin to the remote process
|
||||
let mut rx = spawn_stdin_reader();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = StringBuf::new();
|
||||
|
||||
while let Some(data) = rx.recv().await {
|
||||
match config {
|
||||
// Special exit condition for interactive format
|
||||
_ if buf.trim() == "exit" => {
|
||||
if let Err(_) = tx_stop.send(()) {
|
||||
error!("Failed to close interactive loop!");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// For json format, all stdin is treated as individual requests
|
||||
LoopConfig::Json => {
|
||||
buf.push_str(&data);
|
||||
let (lines, new_buf) = buf.into_full_lines();
|
||||
buf = new_buf;
|
||||
|
||||
// For each complete line, parse it as json and
|
||||
if let Some(lines) = lines {
|
||||
for data in lines.lines() {
|
||||
debug!("Client sending request: {:?}", data);
|
||||
let result = serde_json::from_str(&data)
|
||||
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x));
|
||||
match result {
|
||||
Ok(req) => match client.send(req).await {
|
||||
Ok(res) => match format_response(Format::Json, res) {
|
||||
Ok(out) => out.print(),
|
||||
Err(x) => error!("Failed to format response: {}", x),
|
||||
},
|
||||
Err(x) => {
|
||||
error!("Failed to send request: {}", x)
|
||||
}
|
||||
},
|
||||
Err(x) => {
|
||||
error!("Failed to serialize request ('{}'): {}", data, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For interactive shell format, parse stdin as individual commands
|
||||
LoopConfig::Shell => {
|
||||
buf.push_str(&data);
|
||||
let (lines, new_buf) = buf.into_full_lines();
|
||||
buf = new_buf;
|
||||
|
||||
if let Some(lines) = lines {
|
||||
for data in lines.lines() {
|
||||
trace!("Shell processing line: {:?}", data);
|
||||
if data.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
debug!("Client sending command: {:?}", data);
|
||||
|
||||
// NOTE: We have to stick something in as the first argument as clap/structopt
|
||||
// expect the binary name as the first item in the iterator
|
||||
let result = RequestData::from_iter_safe(
|
||||
std::iter::once("distant")
|
||||
.chain(data.trim().split(' ').filter(|s| !s.trim().is_empty())),
|
||||
);
|
||||
match result {
|
||||
Ok(data) => {
|
||||
match client
|
||||
.send(Request::new(tenant.as_str(), vec![data]))
|
||||
.await
|
||||
{
|
||||
Ok(res) => match format_response(Format::Shell, res) {
|
||||
Ok(out) => out.print(),
|
||||
Err(x) => error!("Failed to format response: {}", x),
|
||||
},
|
||||
Err(x) => {
|
||||
error!("Failed to send request: {}", x)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(x) => {
|
||||
error!("Failed to parse command: {}", x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For non-interactive shell format, all stdin is treated as a proc's stdin
|
||||
LoopConfig::Proc { id } => {
|
||||
debug!("Client sending stdin: {:?}", data);
|
||||
let req =
|
||||
Request::new(tenant.as_str(), vec![RequestData::ProcStdin { id, data }]);
|
||||
let result = client.send(req).await;
|
||||
|
||||
if let Err(x) = result {
|
||||
error!("Failed to send stdin to remote process ({}): {}", id, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
while let Err(TryRecvError::Empty) = rx_stop.try_recv() {
|
||||
if let Some(res) = stream.next().await {
|
||||
let res = res.map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x))?;
|
||||
|
||||
// NOTE: If the loop is for a proxy process, we should assume that the payload
|
||||
// is all-or-nothing for the done check
|
||||
let done = config.is_proc() && res.payload.iter().any(|x| x.is_proc_done());
|
||||
|
||||
format_response(config.into(), res)?.print();
|
||||
|
||||
// If we aren't interactive but are just running a proc and
|
||||
// we've received the end of the proc, we should exit
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
|
||||
// If we have nothing else in our stream, we should also exit
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,221 +0,0 @@
|
||||
use crate::{
|
||||
cli::{
|
||||
opt::{ActionSubcommand, CommonOpt, Format, SessionInput},
|
||||
ExitCode, ExitCodeError,
|
||||
},
|
||||
core::{
|
||||
client::{LspData, Session, SessionInfo, SessionInfoFile},
|
||||
data::{Request, RequestData, ResponseData},
|
||||
net::{DataStream, TransportError},
|
||||
},
|
||||
};
|
||||
use derive_more::{Display, Error, From};
|
||||
use log::*;
|
||||
use tokio::{io, time::Duration};
|
||||
|
||||
pub(crate) mod inner;
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
pub enum Error {
|
||||
IoError(io::Error),
|
||||
TransportError(TransportError),
|
||||
|
||||
#[display(fmt = "Non-interactive but no operation supplied")]
|
||||
MissingOperation,
|
||||
}
|
||||
|
||||
impl ExitCodeError for Error {
|
||||
fn to_exit_code(&self) -> ExitCode {
|
||||
match self {
|
||||
Self::IoError(x) => x.to_exit_code(),
|
||||
Self::TransportError(x) => x.to_exit_code(),
|
||||
Self::MissingOperation => ExitCode::Usage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
|
||||
rt.block_on(async { run_async(cmd, opt).await })
|
||||
}
|
||||
|
||||
async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
|
||||
let timeout = opt.to_timeout_duration();
|
||||
|
||||
match cmd.session {
|
||||
SessionInput::Environment => {
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::File => {
|
||||
let path = cmd.session_data.session_file.clone();
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(
|
||||
SessionInfoFile::load_from(path).await?.into(),
|
||||
timeout,
|
||||
)
|
||||
.await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::Pipe => {
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
SessionInput::Lsp => {
|
||||
let mut data =
|
||||
LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?;
|
||||
let info = data.take_session_info().map_err(io::Error::from)?;
|
||||
start(
|
||||
cmd,
|
||||
Session::tcp_connect_timeout(info, timeout).await?,
|
||||
timeout,
|
||||
Some(data),
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(unix)]
|
||||
SessionInput::Socket => {
|
||||
let path = cmd.session_data.session_socket.clone();
|
||||
start(
|
||||
cmd,
|
||||
Session::unix_connect_timeout(path, None, timeout).await?,
|
||||
timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn start<T>(
|
||||
cmd: ActionSubcommand,
|
||||
mut session: Session<T>,
|
||||
timeout: Duration,
|
||||
lsp_data: Option<LspData>,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
T: DataStream + 'static,
|
||||
{
|
||||
// TODO: Because lsp is being handled in a separate action, we should fail if we get
|
||||
// a session type of lsp for a regular action
|
||||
match (cmd.interactive, cmd.operation) {
|
||||
// ProcRun request is specially handled and we ignore interactive as
|
||||
// the stdin will be used for sending ProcStdin to remote process
|
||||
(_, Some(RequestData::ProcRun { cmd, args })) => {}
|
||||
|
||||
// All other requests without interactive are oneoffs
|
||||
(false, Some(req)) => {
|
||||
let res = session.send_timeout(req, timeout).await?;
|
||||
}
|
||||
|
||||
// Interactive mode will send an optional first request and then continue
|
||||
// to read stdin to send more
|
||||
(true, maybe_req) => {}
|
||||
|
||||
// Not interactive and no operation given
|
||||
(false, None) => Err(Error::MissingOperation),
|
||||
}
|
||||
|
||||
// 1. Determine what type of engagement we're doing
|
||||
// a. Oneoff connection, request, response
|
||||
// b. ProcRun where we take over stdin, stdout, stderr to provide a remote
|
||||
// process experience
|
||||
// c. Lsp where we do the ProcRun stuff, but translate stdin before sending and
|
||||
// stdout before outputting
|
||||
// d. Interactive program
|
||||
//
|
||||
// 2. If we have a queued up operation, we need to perform it
|
||||
// a. For oneoff, this is the request of the oneoff
|
||||
// b. For Procrun, this is the request that starts the process
|
||||
// c. For Lsp, this is the request that starts the process
|
||||
// d. For interactive, this is an optional first request
|
||||
//
|
||||
// 3. If we are using LSP session mode, then we want to send the
|
||||
// ProcStdin request after our optional queued up operation
|
||||
// a. For oneoff, this doesn't make sense and we should fail
|
||||
// b. For ProcRun, we do this after the ProcStart
|
||||
// c. For Lsp, we do this after the ProcStart
|
||||
// d. For interactive, this doesn't make sense as we only support
|
||||
// JSON and shell command input, not LSP input, so this would
|
||||
// fail and we should fail early
|
||||
//
|
||||
// ** LSP would be its own action, which means we want to abstract the logic that feeds
|
||||
// into this start method such that it can also be used with lsp action
|
||||
|
||||
// Make up a tenant name
|
||||
let tenant = utils::new_tenant();
|
||||
|
||||
// Special conditions for continuing to process responses
|
||||
let mut is_proc_req = false;
|
||||
let mut proc_id = 0;
|
||||
|
||||
if let Some(req) = cmd
|
||||
.operation
|
||||
.map(|payload| Request::new(tenant.as_str(), vec![payload]))
|
||||
{
|
||||
// NOTE: We know that there is a single payload entry, so it's all-or-nothing
|
||||
is_proc_req = req.payload.iter().any(|x| x.is_proc_run());
|
||||
|
||||
debug!("Client sending request: {:?}", req);
|
||||
let res = session.send_timeout(req, timeout).await?;
|
||||
|
||||
// Store the spawned process id for using in sending stdin (if we spawned a proc)
|
||||
// NOTE: We can assume that there is a single payload entry in response to our single
|
||||
// entry in our request
|
||||
if let Some(ResponseData::ProcStart { id }) = res.payload.first() {
|
||||
proc_id = *id;
|
||||
}
|
||||
|
||||
inner::format_response(cmd.format, res)?.print();
|
||||
|
||||
// If we also parsed an LSP's initialize request for its session, we want to forward
|
||||
// it along in the case of a process call
|
||||
//
|
||||
// TODO: Do we need to do this somewhere else to apply to all possible ways an LSP
|
||||
// could be started?
|
||||
if let Some(data) = lsp_data {
|
||||
session
|
||||
.fire_timeout(
|
||||
Request::new(
|
||||
tenant.as_str(),
|
||||
vec![RequestData::ProcStdin {
|
||||
id: proc_id,
|
||||
data: data.to_string(),
|
||||
}],
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
// If we are executing a process, we want to continue interacting via stdin and receiving
|
||||
// results via stdout/stderr
|
||||
//
|
||||
// If we are interactive, we want to continue looping regardless
|
||||
if is_proc_req || cmd.interactive {
|
||||
let config = match cmd.format {
|
||||
Format::Json => inner::LoopConfig::Json,
|
||||
Format::Shell if cmd.interactive => inner::LoopConfig::Shell,
|
||||
Format::Shell => inner::LoopConfig::Proc { id: proc_id },
|
||||
};
|
||||
inner::interactive_loop(client, tenant, config).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
use super::DataStream;
|
||||
use std::{future::Future, pin::Pin};
|
||||
use tokio::{
|
||||
io,
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
|
||||
/// Represents a type that has a listen interface
|
||||
pub trait Listener: Send + Sync {
|
||||
type Conn: DataStream;
|
||||
|
||||
/// Async function that accepts a new connection, returning `Ok(Self::Conn)`
|
||||
/// upon receiving the next connection
|
||||
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
|
||||
where
|
||||
Self: Sync + 'a;
|
||||
}
|
||||
|
||||
impl Listener for TcpListener {
|
||||
type Conn = TcpStream;
|
||||
|
||||
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
|
||||
where
|
||||
Self: Sync + 'a,
|
||||
{
|
||||
async fn accept(_self: &TcpListener) -> io::Result<TcpStream> {
|
||||
_self.accept().await.map(|(stream, _)| stream)
|
||||
}
|
||||
|
||||
Box::pin(accept(self))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Listener for tokio::net::UnixListener {
|
||||
type Conn = tokio::net::UnixStream;
|
||||
|
||||
fn accept<'a>(&'a self) -> Pin<Box<dyn Future<Output = io::Result<Self::Conn>> + Send + 'a>>
|
||||
where
|
||||
Self: Sync + 'a,
|
||||
{
|
||||
async fn accept(_self: &tokio::net::UnixListener) -> io::Result<tokio::net::UnixStream> {
|
||||
_self.accept().await.map(|(stream, _)| stream)
|
||||
}
|
||||
|
||||
Box::pin(accept(self))
|
||||
}
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
use derive_more::{Display, Error};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Represents some range of ports
|
||||
#[derive(Clone, Debug, Display, PartialEq, Eq)]
|
||||
#[display(
|
||||
fmt = "{}{}",
|
||||
start,
|
||||
"end.as_ref().map(|end| format!(\"[:{}]\", end)).unwrap_or_default()"
|
||||
)]
|
||||
pub struct PortRange {
|
||||
pub start: u16,
|
||||
pub end: Option<u16>,
|
||||
}
|
||||
|
||||
impl PortRange {
|
||||
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
|
||||
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
|
||||
let mut socket_addrs = Vec::new();
|
||||
let addr = addr.into();
|
||||
|
||||
for port in self.start..=self.end.unwrap_or(self.start) {
|
||||
socket_addrs.push(SocketAddr::from((addr, port)));
|
||||
}
|
||||
|
||||
socket_addrs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq)]
|
||||
pub enum PortRangeParseError {
|
||||
InvalidPort,
|
||||
MissingPort,
|
||||
}
|
||||
|
||||
impl FromStr for PortRange {
|
||||
type Err = PortRangeParseError;
|
||||
|
||||
/// Parses PORT into single range or PORT1:PORTN into full range
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut tokens = s.trim().split(':');
|
||||
let start = tokens
|
||||
.next()
|
||||
.ok_or(PortRangeParseError::MissingPort)?
|
||||
.parse::<u16>()
|
||||
.map_err(|_| PortRangeParseError::InvalidPort)?;
|
||||
let end = if let Some(token) = tokens.next() {
|
||||
Some(
|
||||
token
|
||||
.parse::<u16>()
|
||||
.map_err(|_| PortRangeParseError::InvalidPort)?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if tokens.next().is_some() {
|
||||
return Err(PortRangeParseError::InvalidPort);
|
||||
}
|
||||
|
||||
Ok(Self { start, end })
|
||||
}
|
||||
}
|
@ -1,101 +0,0 @@
|
||||
use log::*;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
runtime::Handle,
|
||||
sync::{Mutex, Notify},
|
||||
time::{self, Instant},
|
||||
};
|
||||
|
||||
pub struct ConnTracker {
|
||||
time: Instant,
|
||||
cnt: usize,
|
||||
}
|
||||
|
||||
impl ConnTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
time: Instant::now(),
|
||||
cnt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment(&mut self) {
|
||||
self.time = Instant::now();
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
pub fn decrement(&mut self) {
|
||||
if self.cnt > 0 {
|
||||
self.time = Instant::now();
|
||||
self.cnt -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn time_and_cnt(&self) -> (Instant, usize) {
|
||||
(self.time, self.cnt)
|
||||
}
|
||||
|
||||
pub fn has_exceeded_timeout(&self, duration: Duration) -> bool {
|
||||
self.cnt == 0 && self.time.elapsed() > duration
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a new task that continues to monitor the time since a
|
||||
/// connection on the server existed, shutting down the runtime
|
||||
/// if the time is exceeded
|
||||
pub fn new_shutdown_task(
|
||||
handle: Handle,
|
||||
duration: Option<Duration>,
|
||||
) -> (Arc<Mutex<ConnTracker>>, Arc<Notify>) {
|
||||
let ct = Arc::new(Mutex::new(ConnTracker::new()));
|
||||
let notify = Arc::new(Notify::new());
|
||||
|
||||
let ct_2 = Arc::clone(&ct);
|
||||
let notify_2 = Arc::clone(¬ify);
|
||||
if let Some(duration) = duration {
|
||||
handle.spawn(async move {
|
||||
loop {
|
||||
// Get the time since the last connection joined/left
|
||||
let (base_time, cnt) = ct_2.lock().await.time_and_cnt();
|
||||
|
||||
// If we have no connections left, we want to wait
|
||||
// until the remaining period has passed and then
|
||||
// verify that we still have no connections
|
||||
if cnt == 0 {
|
||||
// Get the time we should wait based on when the last connection
|
||||
// was dropped; this closes the gap in the case where we start
|
||||
// sometime later than exactly duration since the last check
|
||||
let next_time = base_time + duration;
|
||||
let wait_duration = next_time
|
||||
.checked_duration_since(Instant::now())
|
||||
.unwrap_or_default()
|
||||
+ Duration::from_millis(1);
|
||||
|
||||
// Wait until we've reached our desired duration since the
|
||||
// last connection was dropped
|
||||
time::sleep(wait_duration).await;
|
||||
|
||||
// If we do have a connection at this point, don't exit
|
||||
if !ct_2.lock().await.has_exceeded_timeout(duration) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, we now should exit, which we do by reporting
|
||||
debug!(
|
||||
"Shutdown time of {}s has been reached!",
|
||||
duration.as_secs_f32()
|
||||
);
|
||||
notify_2.notify_one();
|
||||
break;
|
||||
}
|
||||
|
||||
// Otherwise, we just wait the full duration as worst case
|
||||
// we'll have waited just about the time desired if right
|
||||
// after waiting starts the last connection is closed
|
||||
time::sleep(duration).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
(ct, notify)
|
||||
}
|
@ -1,2 +1,8 @@
|
||||
mod distant;
|
||||
pub use distant::{PortRange, PortRangeParseError, Server as DistantServer};
|
||||
mod port;
|
||||
mod relay;
|
||||
mod utils;
|
||||
|
||||
pub use self::distant::DistantServer;
|
||||
pub use port::PortRange;
|
||||
pub use relay::RelayServer;
|
||||
|
@ -0,0 +1,180 @@
|
||||
use derive_more::Display;
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
ops::RangeInclusive,
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
/// Represents some range of ports
|
||||
#[derive(Clone, Debug, Display, PartialEq, Eq)]
|
||||
#[display(
|
||||
fmt = "{}{}",
|
||||
start,
|
||||
"end.as_ref().map(|end| format!(\":{}\", end)).unwrap_or_default()"
|
||||
)]
|
||||
pub struct PortRange {
|
||||
pub start: u16,
|
||||
pub end: Option<u16>,
|
||||
}
|
||||
|
||||
impl PortRange {
|
||||
/// Builds a collection of `SocketAddr` instances from the port range and given ip address
|
||||
pub fn make_socket_addrs(&self, addr: impl Into<IpAddr>) -> Vec<SocketAddr> {
|
||||
let mut socket_addrs = Vec::new();
|
||||
let addr = addr.into();
|
||||
|
||||
for port in self {
|
||||
socket_addrs.push(SocketAddr::from((addr, port)));
|
||||
}
|
||||
|
||||
socket_addrs
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RangeInclusive<u16>> for PortRange {
|
||||
fn from(r: RangeInclusive<u16>) -> Self {
|
||||
let (start, end) = r.into_inner();
|
||||
Self {
|
||||
start,
|
||||
end: Some(end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> IntoIterator for &'a PortRange {
|
||||
type Item = u16;
|
||||
type IntoIter = RangeInclusive<u16>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.start..=self.end.unwrap_or(self.start)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for PortRange {
|
||||
type Item = u16;
|
||||
type IntoIter = RangeInclusive<u16>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.start..=self.end.unwrap_or(self.start)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PortRange {
|
||||
type Err = std::num::ParseIntError;
|
||||
|
||||
/// Parses PORT into single range or PORT1:PORTN into full range
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.split_once(':') {
|
||||
Some((start, end)) => Ok(Self {
|
||||
start: start.parse()?,
|
||||
end: Some(end.parse()?),
|
||||
}),
|
||||
None => Ok(Self {
|
||||
start: s.parse()?,
|
||||
end: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_should_properly_reflect_port_range() {
        // A lone start port renders as just the port number
        let p = PortRange {
            start: 100,
            end: None,
        };
        assert_eq!(p.to_string(), "100");

        // A full range renders as start:end
        let p = PortRange {
            start: 100,
            end: Some(200),
        };
        assert_eq!(p.to_string(), "100:200");
    }

    #[test]
    fn from_range_inclusive_should_map_to_port_range() {
        let p = PortRange::from(100..=200);
        assert_eq!(p.start, 100);
        assert_eq!(p.end, Some(200));
    }

    #[test]
    fn into_iterator_should_support_port_range() {
        // A range without an end yields only the start port,
        // both by reference and by value
        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1]);

        // A range with an end yields every port inclusively
        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!((&p).into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
        assert_eq!(p.into_iter().collect::<Vec<u16>>(), vec![1, 2, 3]);
    }

    #[test]
    fn make_socket_addrs_should_produce_a_socket_addr_per_port() {
        let ip_addr = "127.0.0.1".parse::<IpAddr>().unwrap();

        let p = PortRange {
            start: 1,
            end: None,
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![SocketAddr::new(ip_addr, 1)]
        );

        let p = PortRange {
            start: 1,
            end: Some(3),
        };
        assert_eq!(
            p.make_socket_addrs(ip_addr),
            vec![
                SocketAddr::new(ip_addr, 1),
                SocketAddr::new(ip_addr, 2),
                SocketAddr::new(ip_addr, 3),
            ]
        );
    }

    #[test]
    fn parse_should_fail_if_not_starting_with_number() {
        assert!("100a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_fail_if_provided_end_port_that_is_not_a_number() {
        assert!("100:200a".parse::<PortRange>().is_err());
    }

    #[test]
    fn parse_should_be_able_to_properly_read_in_port_range() {
        // Single port form
        let p: PortRange = "100".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: None
            }
        );

        // Full range form
        let p: PortRange = "100:200".parse().unwrap();
        assert_eq!(
            p,
            PortRange {
                start: 100,
                end: Some(200)
            }
        );
    }
}
|
@ -0,0 +1,336 @@
|
||||
use crate::core::{
|
||||
client::Session,
|
||||
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
|
||||
data::{Request, RequestData, Response, ResponseData},
|
||||
net::{DataStream, Listener, Transport, TransportReadHalf, TransportWriteHalf},
|
||||
server::utils::{ConnTracker, ShutdownTask},
|
||||
};
|
||||
use log::*;
|
||||
use std::{collections::HashMap, marker::Unpin, sync::Arc};
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
sync::{broadcast, mpsc, oneshot, Mutex},
|
||||
task::{JoinError, JoinHandle},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
/// Represents a server that relays requests & responses between connections and the
/// actual server
pub struct RelayServer {
    /// Task forwarding requests from all accepted connections to the session
    forward_task: JoinHandle<()>,
    /// Task accepting new connections from the listener
    accept_task: JoinHandle<()>,
    /// Active connections keyed by their randomly-generated ids
    conns: Arc<Mutex<HashMap<usize, Conn>>>,
}
|
||||
|
||||
impl RelayServer {
    /// Spawns the relay's background tasks: one forwarding requests from
    /// accepted connections into the given session, and one accepting new
    /// connections from the listener until the listener closes or the
    /// optional `shutdown_after` timeout (no active connections) fires
    pub async fn initialize<T1, T2, L>(
        mut session: Session<T1>,
        listener: L,
        shutdown_after: Option<Duration>,
    ) -> io::Result<Self>
    where
        T1: DataStream + 'static,
        T2: DataStream + Send + 'static,
        L: Listener<Conn = T2> + 'static,
    {
        // Get a copy of our session's broadcaster so we can have each connection
        // subscribe to it for new messages filtered by tenant
        debug!("Acquiring session broadcaster");
        let broadcaster = session.to_response_broadcaster();

        // Spawn task to send to the server requests from connections
        debug!("Spawning request forwarding task");
        let (req_tx, mut req_rx) = mpsc::channel::<Request>(CLIENT_BROADCAST_CHANNEL_CAPACITY);
        let forward_task = tokio::spawn(async move {
            // Stops when every req_tx clone has been dropped or the session
            // fails to send
            while let Some(req) = req_rx.recv().await {
                debug!(
                    "Forwarding request of type{} {} to server",
                    if req.payload.len() > 1 { "s" } else { "" },
                    req.to_payload_type_string()
                );
                if let Err(x) = session.fire(req).await {
                    error!("Session failed to send request: {:?}", x);
                    break;
                }
            }
        });

        let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after);
        let conns = Arc::new(Mutex::new(HashMap::new()));
        let conns_2 = Arc::clone(&conns);
        let accept_task = tokio::spawn(async move {
            // Accept loop: each successful accept becomes a tracked Conn;
            // failed initialization only skips that connection
            let inner = async move {
                loop {
                    match listener.accept().await {
                        Ok(stream) => {
                            let result = Conn::initialize(
                                stream,
                                req_tx.clone(),
                                broadcaster.clone(),
                                tracker.as_ref().map(Arc::clone),
                            )
                            .await;

                            match result {
                                Ok(conn) => conns_2.lock().await.insert(conn.id(), conn),
                                Err(x) => {
                                    error!("Failed to initialize connection: {}", x);
                                    continue;
                                }
                            };
                        }
                        Err(x) => {
                            debug!("Listener has closed: {}", x);
                            break;
                        }
                    }
                }
            };

            // Run the accept loop until it finishes or the shutdown future
            // completes, whichever happens first
            tokio::select! {
                _ = inner => {}
                _ = shutdown => {
                    warn!("Reached shutdown timeout, so terminating");
                }
            }
        });

        Ok(Self {
            forward_task,
            accept_task,
            conns,
        })
    }

    /// Waits for both background tasks to complete, yielding the first join
    /// error encountered if either panicked or was aborted
    pub async fn wait(self) -> Result<(), JoinError> {
        match tokio::try_join!(self.forward_task, self.accept_task) {
            Ok(_) => Ok(()),
            Err(x) => Err(x),
        }
    }

    /// Aborts the relay's background tasks as well as every tracked connection
    pub async fn abort(&self) {
        self.forward_task.abort();
        self.accept_task.abort();
        self.conns
            .lock()
            .await
            .values()
            .for_each(|conn| conn.abort());
    }
}
|
||||
|
||||
/// Represents an individual relayed connection, identified by a random id
struct Conn {
    // Randomly-generated identifier used as the key in RelayServer::conns
    id: usize,
    /// Task reading requests from the connection and forwarding them to the session
    req_task: JoinHandle<()>,
    /// Task forwarding matching session responses back to the connection
    res_task: JoinHandle<()>,
}
|
||||
|
||||
/// Represents state associated with a connection
#[derive(Default)]
struct ConnState {
    /// Ids of remote processes started on behalf of this connection; used to
    /// send kill requests for them once the connection disconnects
    processes: Vec<usize>,
}
|
||||
|
||||
impl Conn {
    /// Performs the transport handshake over the given stream and spawns the
    /// two tasks (incoming requests, outgoing responses) servicing the
    /// connection; optionally reports activity to the shutdown tracker
    pub async fn initialize<T>(
        stream: T,
        req_tx: mpsc::Sender<Request>,
        res_broadcaster: broadcast::Sender<Response>,
        ct: Option<Arc<Mutex<ConnTracker>>>,
    ) -> io::Result<Self>
    where
        T: DataStream + 'static,
    {
        // Create a unique id to associate with the connection since its address
        // is not guaranteed to have an identifiable string
        let id: usize = rand::random();

        // Establish a proper connection via a handshake, discarding the connection otherwise
        let transport = Transport::from_handshake(stream, None).await.map_err(|x| {
            error!("<Conn @ {}> Failed handshake: {}", id, x);
            io::Error::new(io::ErrorKind::Other, x)
        })?;
        let (t_read, t_write) = transport.into_split();

        // Used to alert our response task of the connection's tenant name
        // based on the first request received
        let (tenant_tx, tenant_rx) = oneshot::channel();

        // Create a state we use to keep track of connection-specific data
        debug!("<Conn @ {}> Initializing internal state", id);
        let state = Arc::new(Mutex::new(ConnState::default()));

        // Spawn task to continually receive responses from the session that
        // may or may not be relevant to the connection, which will filter
        // by tenant and then pass along any response that matches
        let res_rx = res_broadcaster.subscribe();
        let state_2 = Arc::clone(&state);
        let res_task = tokio::spawn(async move {
            handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await;
        });

        // Spawn task to continually read requests from connection and forward
        // them along to be sent via the session
        let req_tx = req_tx.clone();
        let req_task = tokio::spawn(async move {
            // Count this connection toward the shutdown tracker for as long
            // as the incoming-request task is alive
            if let Some(ct) = ct.as_ref() {
                ct.lock().await.increment();
            }
            handle_conn_incoming(id, state, t_read, tenant_tx, req_tx).await;
            if let Some(ct) = ct.as_ref() {
                ct.lock().await.decrement();
            }
            debug!("<Conn @ {}> Disconnected", id);
        });

        Ok(Self {
            id,
            req_task,
            res_task,
        })
    }

    /// Id associated with the connection
    pub fn id(&self) -> usize {
        self.id
    }

    /// Aborts the connection from the server side
    pub fn abort(&self) {
        self.req_task.abort();
        self.res_task.abort();
    }
}
|
||||
|
||||
/// Conn::Request -> Session::Fire
|
||||
async fn handle_conn_incoming<T>(
|
||||
conn_id: usize,
|
||||
state: Arc<Mutex<ConnState>>,
|
||||
mut reader: TransportReadHalf<T>,
|
||||
tenant_tx: oneshot::Sender<String>,
|
||||
req_tx: mpsc::Sender<Request>,
|
||||
) where
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
macro_rules! process_req {
|
||||
($on_success:expr; $done:expr) => {
|
||||
match reader.receive::<Request>().await {
|
||||
Ok(Some(req)) => {
|
||||
$on_success(&req);
|
||||
if let Err(x) = req_tx.send(req).await {
|
||||
error!(
|
||||
"Failed to pass along request received on unix socket: {:?}",
|
||||
x
|
||||
);
|
||||
$done;
|
||||
}
|
||||
}
|
||||
Ok(None) => $done,
|
||||
Err(x) => {
|
||||
error!("Failed to receive request from unix stream: {:?}", x);
|
||||
$done;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let mut tenant = None;
|
||||
|
||||
// NOTE: Have to acquire our first request outside our loop since the oneshot
|
||||
// sender of the tenant's name is consuming
|
||||
process_req!(
|
||||
|req: &Request| {
|
||||
tenant = Some(req.tenant.clone());
|
||||
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
|
||||
error!("Failed to send along acquired tenant name: {:?}", x);
|
||||
return;
|
||||
}
|
||||
};
|
||||
return
|
||||
);
|
||||
|
||||
// Loop and process all additional requests
|
||||
loop {
|
||||
process_req!(|_| {}; break);
|
||||
}
|
||||
|
||||
// At this point, we have processed at least one request successfully
|
||||
// and should have the tenant populated. If we had a failure at the
|
||||
// beginning, we exit the function early via return.
|
||||
let tenant = tenant.unwrap();
|
||||
|
||||
// Perform cleanup if done by sending a request to kill each running process
|
||||
// debug!("Cleaning conn {} :: killing process {}", conn_id, id);
|
||||
if let Err(x) = req_tx
|
||||
.send(Request::new(
|
||||
tenant.clone(),
|
||||
state
|
||||
.lock()
|
||||
.await
|
||||
.processes
|
||||
.iter()
|
||||
.map(|id| RequestData::ProcKill { id: *id })
|
||||
.collect(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("<Conn @ {}> Failed to send kill signals: {}", conn_id, x);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_conn_outgoing<T>(
|
||||
conn_id: usize,
|
||||
state: Arc<Mutex<ConnState>>,
|
||||
mut writer: TransportWriteHalf<T>,
|
||||
tenant_rx: oneshot::Receiver<String>,
|
||||
mut res_rx: broadcast::Receiver<Response>,
|
||||
) where
|
||||
T: AsyncWrite + Unpin,
|
||||
{
|
||||
// We wait for the tenant to be identified by the first request
|
||||
// before processing responses to be sent back; this is easier
|
||||
// to implement and yields the same result as we would be dropping
|
||||
// all responses before we know the tenant
|
||||
if let Ok(tenant) = tenant_rx.await {
|
||||
debug!("Associated tenant {} with conn {}", tenant, conn_id);
|
||||
loop {
|
||||
match res_rx.recv().await {
|
||||
// Forward along responses that are for our connection
|
||||
Ok(res) if res.tenant == tenant => {
|
||||
debug!(
|
||||
"Conn {} being sent response of type{} {}",
|
||||
conn_id,
|
||||
if res.payload.len() > 1 { "s" } else { "" },
|
||||
res.to_payload_type_string(),
|
||||
);
|
||||
|
||||
// If a new process was started, we want to capture the id and
|
||||
// associate it with the connection
|
||||
let ids = res.payload.iter().filter_map(|x| match x {
|
||||
ResponseData::ProcStart { id } => Some(*id),
|
||||
_ => None,
|
||||
});
|
||||
for id in ids {
|
||||
debug!("Tracking proc {} for conn {}", id, conn_id);
|
||||
state.lock().await.processes.push(id);
|
||||
}
|
||||
|
||||
if let Err(x) = writer.send(res).await {
|
||||
error!("Failed to send response through unix connection: {}", x);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Skip responses that are not for our connection
|
||||
Ok(_) => {}
|
||||
Err(x) => {
|
||||
error!(
|
||||
"Conn {} failed to receive broadcast response: {}",
|
||||
conn_id, x
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,291 @@
|
||||
use futures::future::OptionFuture;
|
||||
use log::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
sync::Mutex,
|
||||
task::{JoinError, JoinHandle},
|
||||
time::{self, Instant},
|
||||
};
|
||||
|
||||
/// Task to keep track of a possible server shutdown based on connections
pub struct ShutdownTask {
    /// Background task that completes once the shutdown timeout is reached
    task: JoinHandle<()>,
    /// Shared tracker counting connections and recording last activity time
    tracker: Arc<Mutex<ConnTracker>>,
}
|
||||
|
||||
impl Future for ShutdownTask {
|
||||
type Output = Result<(), JoinError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.task).poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ShutdownTask {
    /// Given an optional timeout, will either create the shutdown task or not,
    /// returning an optional future for the completion of the shutdown task
    /// alongside an optional connection tracker
    ///
    /// NOTE(review): when `duration` is `None`, the returned `OptionFuture`
    /// wraps `None` and resolves immediately with `None` — confirm that
    /// callers selecting on it treat that as "no shutdown configured" rather
    /// than as an instant timeout
    pub fn maybe_initialize(
        duration: Option<Duration>,
    ) -> (OptionFuture<ShutdownTask>, Option<Arc<Mutex<ConnTracker>>>) {
        match duration {
            Some(duration) => {
                let task = Self::initialize(duration);
                let tracker = task.tracker();
                let task: OptionFuture<_> = Some(task).into();
                (task, Some(tracker))
            }
            None => (None.into(), None),
        }
    }

    /// Spawns a new task that continues to monitor the time since a
    /// connection on the server existed, reporting a shutdown to all listeners
    /// once the timeout is exceeded
    pub fn initialize(duration: Duration) -> Self {
        let tracker = Arc::new(Mutex::new(ConnTracker::new()));

        let tracker_2 = Arc::clone(&tracker);
        let task = tokio::spawn(async move {
            loop {
                // Get the time since the last connection joined/left
                let (base_time, cnt) = tracker_2.lock().await.time_and_cnt();

                // If we have no connections left, we want to wait
                // until the remaining period has passed and then
                // verify that we still have no connections
                if cnt == 0 {
                    // Get the time we should wait based on when the last connection
                    // was dropped; this closes the gap in the case where we start
                    // sometime later than exactly duration since the last check
                    let next_time = base_time + duration;
                    let wait_duration = next_time
                        .checked_duration_since(Instant::now())
                        .unwrap_or_default()
                        + Duration::from_millis(1);

                    // Wait until we've reached our desired duration since the
                    // last connection was dropped
                    time::sleep(wait_duration).await;

                    // If we do have a connection at this point, don't exit
                    if !tracker_2.lock().await.has_reached_timeout(duration) {
                        continue;
                    }

                    // Otherwise, we now should exit, which we do by reporting
                    debug!(
                        "Shutdown time of {}s has been reached!",
                        duration.as_secs_f32()
                    );
                    break;
                }

                // Otherwise, we just wait the full duration as worst case
                // we'll have waited just about the time desired if right
                // after waiting starts the last connection is closed
                time::sleep(duration).await;
            }
        });

        Self { task, tracker }
    }

    /// Produces a new copy of the connection tracker associated with the shutdown manager
    pub fn tracker(&self) -> Arc<Mutex<ConnTracker>> {
        Arc::clone(&self.tracker)
    }
}
|
||||
|
||||
/// Tracks the number of active connections together with the instant of the
/// last change in that count, used to decide when a shutdown timeout elapses
pub struct ConnTracker {
    /// Instant of the most recent increment/decrement
    time: Instant,
    /// Number of currently active connections
    cnt: usize,
}

// Provide Default alongside `new` (clippy: new_without_default)
impl Default for ConnTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl ConnTracker {
    /// Creates a tracker with zero connections, stamped with the current time
    pub fn new() -> Self {
        Self {
            time: Instant::now(),
            cnt: 0,
        }
    }

    /// Records a new connection, refreshing the activity timestamp
    pub fn increment(&mut self) {
        self.time = Instant::now();
        self.cnt += 1;
    }

    /// Records a dropped connection, refreshing the activity timestamp;
    /// a no-op when the count is already zero
    pub fn decrement(&mut self) {
        if self.cnt > 0 {
            self.time = Instant::now();
            self.cnt -= 1;
        }
    }

    /// Returns the last activity time and the current connection count
    fn time_and_cnt(&self) -> (Instant, usize) {
        (self.time, self.cnt)
    }

    /// Returns true when there are no connections and at least `duration` has
    /// passed since the last activity
    fn has_reached_timeout(&self, duration: Duration) -> bool {
        self.cnt == 0 && self.time.elapsed() >= duration
    }
}
|
||||
|
||||
// Unit tests for ShutdownTask and ConnTracker.
// FIX: module was misspelled `tsets`, which hid it from conventional
// `cargo test` module filters and tooling expecting `tests`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[tokio::test]
    async fn shutdown_task_should_not_resolve_if_has_connection_regardless_of_time() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        task.tracker().lock().await.increment();
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );

        time::sleep(Duration::from_millis(15)).await;

        assert!(
            futures::poll!(task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
    }

    #[tokio::test]
    async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );

        time::sleep(Duration::from_millis(15)).await;

        assert!(
            futures::poll!(task).is_ready(),
            "Shutdown task unexpectedly pending"
        );
    }

    #[tokio::test]
    async fn shutdown_task_should_resolve_if_no_connection_for_minimum_duration_after_connection_removed(
    ) {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        task.tracker().lock().await.increment();
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );

        time::sleep(Duration::from_millis(15)).await;
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );

        task.tracker().lock().await.decrement();
        time::sleep(Duration::from_millis(15)).await;

        assert!(
            futures::poll!(task).is_ready(),
            "Shutdown task unexpectedly pending"
        );
    }

    #[tokio::test]
    async fn shutdown_task_should_not_resolve_before_minimum_duration() {
        let mut task = ShutdownTask::initialize(Duration::from_millis(10));
        assert!(
            futures::poll!(&mut task).is_pending(),
            "Shutdown task unexpectedly completed"
        );

        time::sleep(Duration::from_millis(5)).await;

        assert!(
            futures::poll!(task).is_pending(),
            "Shutdown task unexpectedly completed"
        );
    }

    #[test]
    fn conn_tracker_should_update_time_when_incremented() {
        let mut tracker = ConnTracker::new();
        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);

        // Wait to ensure that the new time will be different
        thread::sleep(Duration::from_millis(1));

        tracker.increment();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);
        assert!(new_time > old_time);
    }

    #[test]
    fn conn_tracker_should_update_time_when_decremented() {
        let mut tracker = ConnTracker::new();
        tracker.increment();

        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);

        // Wait to ensure that the new time will be different
        thread::sleep(Duration::from_millis(1));

        tracker.decrement();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        assert!(new_time > old_time);
    }

    #[test]
    fn conn_tracker_should_not_update_time_when_decremented_if_at_zero_already() {
        let mut tracker = ConnTracker::new();
        let (old_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);

        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));

        tracker.decrement();
        let (new_time, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);
        assert!(new_time == old_time);
    }

    #[test]
    fn conn_tracker_should_report_timeout_reached_when_time_has_elapsed_and_no_connections() {
        let tracker = ConnTracker::new();
        let (_, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 0);

        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));

        assert!(tracker.has_reached_timeout(Duration::from_millis(1)));
    }

    #[test]
    fn conn_tracker_should_not_report_timeout_reached_when_time_has_elapsed_but_has_connections() {
        let mut tracker = ConnTracker::new();
        tracker.increment();

        let (_, cnt) = tracker.time_and_cnt();
        assert_eq!(cnt, 1);

        // Wait to ensure that the new time would be different if updated
        thread::sleep(Duration::from_millis(1));

        assert!(!tracker.has_reached_timeout(Duration::from_millis(1)));
    }
}
|
Loading…
Reference in New Issue