diff --git a/Cargo.lock b/Cargo.lock index a10ef13..0db7d84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,6 +140,7 @@ dependencies = [ "lazy_static", "log", "orion", + "rand", "serde", "serde_cbor", "serde_json", @@ -492,6 +493,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -546,6 +553,46 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", + "rand_hc", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_hc" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7" +dependencies = [ + "rand_core", +] + [[package]] name = "redox_syscall" version = "0.2.9" @@ -803,6 +850,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fea2f46..8c321d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,13 @@ futures = "0.3.16" hex = "0.4.3" log = "0.4.14" orion = "0.16.0" +rand = "0.8.4" serde = { version = "1.0.126", features = ["derive"] } serde_cbor = "0.11.1" serde_json = "1.0.64" strum = { version = "0.21.0", features = ["derive"] } tokio = { version = "1.9.0", features = ["full"] } -tokio-stream = "0.1.7" +tokio-stream = { version = "0.1.7", features = ["sync"] } tokio-util = { version = "0.6.7", features = ["codec"] } # Binary-specific dependencies diff --git a/src/data.rs b/src/data.rs index 4303f46..84f7adc 100644 --- a/src/data.rs +++ b/src/data.rs @@ -3,16 +3,35 @@ use std::path::PathBuf; use structopt::StructOpt; use strum::AsRefStr; -/// Represents an operation to be performed on the remote machine +/// Represents the request to be performed on the remote machine +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub struct Request { + /// A unique id associated with the request + pub id: usize, + + /// The main payload containing the type and data of the request + pub payload: RequestPayload, +} + +impl From for Request { + /// Produces a new request with the given payload and a randomly-generated id + fn from(payload: RequestPayload) -> Self { + let id = rand::random(); + Self { id, payload } + } +} + +/// Represents the payload of a request to be performed on the remote machine #[derive(Clone, Debug, PartialEq, Eq, AsRefStr, StructOpt, Serialize, Deserialize)] #[serde( rename_all = "snake_case", deny_unknown_fields, tag = "type", - content = "payload" + content = "data" )] #[strum(serialize_all = "snake_case")] -pub enum Operation { +pub enum RequestPayload { /// Reads a file from the specified path on the remote machine #[structopt(visible_aliases = &["cat"])] FileRead { @@ -137,24 +156,44 @@ pub enum Operation { ProcList {}, } -/// Represents an response to an operation performed on the remote machine -#[derive(Clone, Debug, PartialEq, Eq, AsRefStr, Serialize, Deserialize)] -#[serde( - rename_all = "snake_case", - deny_unknown_fields, - tag = "status", - content = "payload" -)] -#[strum(serialize_all = "snake_case")] -pub enum Response { - /// Represents a successfully-handled operation - Ok(ResponsePayload), +/// Represents an response to a request performed on the remote machine +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub struct Response { + /// A unique id associated with the response + pub id: usize, - /// Represents an operation that failed - Error { - /// The message associated with the failure - msg: String, - }, + /// The id of the originating request, if there was one + /// (some responses are sent unprompted) + pub origin_id: Option, + + /// The main payload containing the type and data of the response + pub payload: ResponsePayload, +} + +impl Response { + /// Produces a new response with the given payload and origin id while supplying + /// randomly-generated id + pub fn from_payload_with_origin(payload: ResponsePayload, origin_id: usize) -> Self { + let id = rand::random(); + Self { + id, + origin_id: Some(origin_id), + payload, + } + } +} + +impl From for Response { + /// Produces a new response with the given payload, no origin id, and a randomly-generated id + fn from(payload: ResponsePayload) -> Self { + let id = rand::random(); + Self { + id, + origin_id: None, + payload, + } + } } /// Represents the payload of a successful response @@ -166,75 +205,36 @@ pub enum Response { content = "data" )] pub enum ResponsePayload { - /// Response to reading a file - FileRead { - /// The path to the file on the remote machine - path: PathBuf, + /// General okay with no extra data, returned in cases like + /// creating or removing a directory, copying a file, or renaming + /// a file + Ok, - /// Contents of the file - data: Vec, + /// General-purpose failure that occurred from some request + Error { + /// Details about the error + description: String, }, - /// Response to writing a file - FileWrite { - /// The path to the file on the remote machine - path: PathBuf, - - /// Total bytes written - bytes_written: usize, + /// Response containing some arbitrary, binary data + Blob { + /// Binary data associated with the response + data: Vec, }, - /// Response to appending to a file - FileAppend { - /// The path to the file on the remote machine - path: PathBuf, - + /// Response when some data was written on the remote machine + /// such as a file write or append + Written { /// Total bytes written bytes_written: usize, }, /// Response to reading a directory - DirRead { - /// The path to the directory on the remote machine - path: PathBuf, - - /// Entries contained within directory + DirEntries { + /// Entries contained within the requested directory entries: Vec, }, - /// Response to creating a directory - DirCreate { - /// The path to the directory on the remote machine - path: PathBuf, - }, - - /// Response to removing a directory - DirRemove { - /// The path to the directory on the remote machine - path: PathBuf, - - /// Total files & directories removed within the directory (0 if directory was empty) - total_removed: usize, - }, - - /// Response to copying a file/directory - Copy { - /// The path to the file/directory on the remote machine - src: PathBuf, - - /// New location on the remote machine for copy of file/directory - dst: PathBuf, - }, - - /// Response to moving/renaming a file/directory - Rename { - /// The path to the file/directory on the remote machine - src: PathBuf, - - /// New location on the remote machine for the file/directory - dst: PathBuf, - }, - /// Response to starting a new process ProcStart { /// Arbitrary id associated with running process diff --git a/src/net/mod.rs b/src/net/mod.rs index 92634aa..bc1c447 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -1,75 +1,102 @@ -use crate::utils::Session; -use codec::{DistantCodec, DistantCodecError}; -use derive_more::{Display, Error, From}; -use futures::SinkExt; -use orion::{ - aead::{self, SecretKey}, - errors::UnknownCryptoError, +mod transport; +pub use transport::{Transport, TransportError}; + +use crate::{ + data::{Request, Response, ResponsePayload}, + utils::Session, +}; +use log::*; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, }; -use serde::{de::DeserializeOwned, Serialize}; -use std::sync::Arc; -use tokio::{io, net::TcpStream}; -use tokio_stream::StreamExt; -use tokio_util::codec::Framed; +use tokio::{ + io, + sync::{oneshot, watch}, +}; +use tokio_stream::wrappers::WatchStream; -mod codec; +type Callbacks = Arc>>>; -#[derive(Debug, Display, Error, From)] -pub enum TransportError { - CodecError(DistantCodecError), - EncryptError(UnknownCryptoError), - IoError(io::Error), - SerializeError(serde_cbor::Error), -} +/// Represents a client that can make requests against a server +pub struct Client { + /// Underlying transport used by client + transport: Arc>, -/// Represents a transport of data across the network -pub struct Transport { - inner: Framed, - key: Arc, -} + /// Collection of callbacks to be invoked upon receiving a response to a request + callbacks: Callbacks, -impl Transport { - /// Wraps a `TcpStream` and associated credentials in a transport layer - pub fn new(stream: TcpStream, key: Arc) -> Self { - Self { - inner: Framed::new(stream, DistantCodec), - key, - } - } + /// Callback to trigger when a response is received without an origin or with an origin + /// not found in the list of callbacks + rx: watch::Receiver, +} +impl Client { /// Establishes a connection using the provided session pub async fn connect(session: Session) -> io::Result { - let stream = TcpStream::connect(session.to_socket_addr().await?).await?; - Ok(Self::new(stream, Arc::new(session.key))) - } + let transport = Arc::new(tokio::sync::Mutex::new(Transport::connect(session).await?)); + let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new())); + let (tx, rx) = watch::channel(Response::from(ResponsePayload::Error { + description: String::from("Fake server response"), + })); + + // Start a task that continually checks for responses and triggers callbacks + let transport_2 = Arc::clone(&transport); + let callbacks_2 = Arc::clone(&callbacks); + tokio::spawn(async move { + loop { + match transport_2.lock().await.receive::().await { + Ok(Some(res)) => { + let maybe_callback = res + .origin_id + .as_ref() + .and_then(|id| callbacks_2.lock().unwrap().remove(id)); + + // If there is an origin to this response, trigger the callback + if let Some(tx) = maybe_callback { + if let Err(res) = tx.send(res) { + error!("Failed to trigger callback for response {}", res.id); + } - /// Sends some data across the wire - pub async fn send(&mut self, data: T) -> Result<(), TransportError> { - // Serialize, encrypt, and then (TODO) sign - // NOTE: Cannot used packed implementation for now due to issues with deserialization - let data = serde_cbor::to_vec(&data)?; - let data = aead::seal(&self.key, &data)?; + // Otherwise, this goes into the junk draw of response handlers + } else { + if let Err(x) = tx.send(res) { + error!("Failed to trigger watch: {}", x); + } + } + } + Ok(None) => break, + Err(x) => { + error!("{}", x); + break; + } + } + } + }); - self.inner - .send(&data) - .await - .map_err(TransportError::CodecError) + Ok(Self { + transport, + callbacks, + rx, + }) } - /// Receives some data from out on the wire, waiting until it's available, - /// returning none if the transport is now closed - pub async fn receive(&mut self) -> Result, TransportError> { - // If data is received, we process like usual - if let Some(data) = self.inner.next().await { - // Validate (TODO) signature, decrypt, and then deserialize - let data = data?; - let data = aead::open(&self.key, &data)?; - let data = serde_cbor::from_slice(&data)?; - Ok(Some(data)) + /// Sends a request and waits for a response + pub async fn send(&self, req: Request) -> Result { + // First, add a callback that will trigger when we get the response for this request + let (tx, rx) = oneshot::channel(); + self.callbacks.lock().unwrap().insert(req.id, tx); + + // Second, send the request + self.transport.lock().await.send(req).await?; + + // Third, wait for the response + rx.await + .map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x))) + } - // Otherwise, if no data is received, this means that our socket has closed - } else { - Ok(None) - } + /// Creates and returns a new stream of responses that are received with no originating request + pub fn to_response_stream(&self) -> WatchStream { + WatchStream::new(self.rx.clone()) } } diff --git a/src/net/codec.rs b/src/net/transport/codec.rs similarity index 100% rename from src/net/codec.rs rename to src/net/transport/codec.rs diff --git a/src/net/transport/mod.rs b/src/net/transport/mod.rs new file mode 100644 index 0000000..92634aa --- /dev/null +++ b/src/net/transport/mod.rs @@ -0,0 +1,75 @@ +use crate::utils::Session; +use codec::{DistantCodec, DistantCodecError}; +use derive_more::{Display, Error, From}; +use futures::SinkExt; +use orion::{ + aead::{self, SecretKey}, + errors::UnknownCryptoError, +}; +use serde::{de::DeserializeOwned, Serialize}; +use std::sync::Arc; +use tokio::{io, net::TcpStream}; +use tokio_stream::StreamExt; +use tokio_util::codec::Framed; + +mod codec; + +#[derive(Debug, Display, Error, From)] +pub enum TransportError { + CodecError(DistantCodecError), + EncryptError(UnknownCryptoError), + IoError(io::Error), + SerializeError(serde_cbor::Error), +} + +/// Represents a transport of data across the network +pub struct Transport { + inner: Framed, + key: Arc, +} + +impl Transport { + /// Wraps a `TcpStream` and associated credentials in a transport layer + pub fn new(stream: TcpStream, key: Arc) -> Self { + Self { + inner: Framed::new(stream, DistantCodec), + key, + } + } + + /// Establishes a connection using the provided session + pub async fn connect(session: Session) -> io::Result { + let stream = TcpStream::connect(session.to_socket_addr().await?).await?; + Ok(Self::new(stream, Arc::new(session.key))) + } + + /// Sends some data across the wire + pub async fn send(&mut self, data: T) -> Result<(), TransportError> { + // Serialize, encrypt, and then (TODO) sign + // NOTE: Cannot used packed implementation for now due to issues with deserialization + let data = serde_cbor::to_vec(&data)?; + let data = aead::seal(&self.key, &data)?; + + self.inner + .send(&data) + .await + .map_err(TransportError::CodecError) + } + + /// Receives some data from out on the wire, waiting until it's available, + /// returning none if the transport is now closed + pub async fn receive(&mut self) -> Result, TransportError> { + // If data is received, we process like usual + if let Some(data) = self.inner.next().await { + // Validate (TODO) signature, decrypt, and then deserialize + let data = data?; + let data = aead::open(&self.key, &data)?; + let data = serde_cbor::from_slice(&data)?; + Ok(Some(data)) + + // Otherwise, if no data is received, this means that our socket has closed + } else { + Ok(None) + } + } +} diff --git a/src/opt.rs b/src/opt.rs index 7f324dd..1fedc32 100644 --- a/src/opt.rs +++ b/src/opt.rs @@ -1,4 +1,4 @@ -use crate::{subcommand, data::Operation}; +use crate::{subcommand, data::RequestPayload}; use derive_more::{Display, Error, From}; use lazy_static::lazy_static; use std::{ @@ -116,7 +116,7 @@ pub struct ExecuteSubcommand { pub format: ExecuteFormat, #[structopt(subcommand)] - pub operation: Operation, + pub operation: RequestPayload, } /// Represents options for binding a server to an IP address diff --git a/src/subcommand/execute.rs b/src/subcommand/execute.rs index 464e0c1..e4f7963 100644 --- a/src/subcommand/execute.rs +++ b/src/subcommand/execute.rs @@ -1,7 +1,7 @@ use crate::{ - data::Response, - net::{Transport, TransportError}, - opt::{CommonOpt, ExecuteSubcommand}, + data::{Request, Response}, + net::{Client, TransportError}, + opt::{CommonOpt, ExecuteFormat, ExecuteSubcommand}, utils::{Session, SessionError}, }; use derive_more::{Display, Error, From}; @@ -22,16 +22,20 @@ pub fn run(cmd: ExecuteSubcommand, opt: CommonOpt) -> Result<(), Error> { async fn run_async(cmd: ExecuteSubcommand, _opt: CommonOpt) -> Result<(), Error> { let session = Session::load().await?; - let mut transport = Transport::connect(session).await?; + let client = Client::connect(session).await?; - // Send our operation - transport.send(cmd.operation).await?; + let req = Request::from(cmd.operation); - // Continue to receive and process responses as long as we get them or we decide to end - loop { - let response = transport.receive::().await?; - println!("RESPONSE: {:?}", response); - } + let res = client.send(req).await?; + let res_string = match cmd.format { + ExecuteFormat::Json => serde_json::to_string(&res) + .map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?, + ExecuteFormat::Shell => format!("{:?}", res), + }; + println!("{}", res_string); + + // TODO: Process result to determine if we want to create a watch stream and continue + // to examine results Ok(()) } diff --git a/src/subcommand/listen.rs b/src/subcommand/listen.rs index 2a20f9e..2894cd0 100644 --- a/src/subcommand/listen.rs +++ b/src/subcommand/listen.rs @@ -1,5 +1,5 @@ use crate::{ - data::{Operation, Response}, + data::{Request, Response, ResponsePayload}, net::Transport, opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand}, }; @@ -90,17 +90,20 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R // Spawn a new task that loops to handle requests from the client tokio::spawn(async move { loop { - match transport.receive::().await { + match transport.receive::().await { Ok(Some(request)) => { trace!( " Received request of type {}", addr_string.as_str(), - request.as_ref() + request.payload.as_ref() ); - let response = Response::Error { - msg: String::from("Unimplemented"), - }; + let response = Response::from_payload_with_origin( + ResponsePayload::Error { + description: String::from("Unimplemented"), + }, + request.id, + ); if let Err(x) = transport.send(response).await { error!(" {}", addr_string.as_str(), x);