Refactor code to have request/response format w/ client wrapper for easier processing

pull/38/head
Chip Senkbeil 3 years ago
parent f6fa3e606e
commit f2cce4aa34
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

48
Cargo.lock generated

@ -140,6 +140,7 @@ dependencies = [
"lazy_static",
"log",
"orion",
"rand",
"serde",
"serde_cbor",
"serde_json",
@ -492,6 +493,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "ppv-lite86"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
@ -546,6 +553,46 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_hc",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7"
dependencies = [
"getrandom",
]
[[package]]
name = "rand_hc"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7"
dependencies = [
"rand_core",
]
[[package]]
name = "redox_syscall"
version = "0.2.9"
@ -803,6 +850,7 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]

@ -18,12 +18,13 @@ futures = "0.3.16"
hex = "0.4.3"
log = "0.4.14"
orion = "0.16.0"
rand = "0.8.4"
serde = { version = "1.0.126", features = ["derive"] }
serde_cbor = "0.11.1"
serde_json = "1.0.64"
strum = { version = "0.21.0", features = ["derive"] }
tokio = { version = "1.9.0", features = ["full"] }
tokio-stream = "0.1.7"
tokio-stream = { version = "0.1.7", features = ["sync"] }
tokio-util = { version = "0.6.7", features = ["codec"] }
# Binary-specific dependencies

@ -3,16 +3,35 @@ use std::path::PathBuf;
use structopt::StructOpt;
use strum::AsRefStr;
/// Represents an operation to be performed on the remote machine
/// Represents the request to be performed on the remote machine
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Request {
/// A unique id associated with the request
pub id: usize,
/// The main payload containing the type and data of the request
pub payload: RequestPayload,
}
impl From<RequestPayload> for Request {
/// Produces a new request with the given payload and a randomly-generated id
fn from(payload: RequestPayload) -> Self {
let id = rand::random();
Self { id, payload }
}
}
/// Represents the payload of a request to be performed on the remote machine
#[derive(Clone, Debug, PartialEq, Eq, AsRefStr, StructOpt, Serialize, Deserialize)]
#[serde(
rename_all = "snake_case",
deny_unknown_fields,
tag = "type",
content = "payload"
content = "data"
)]
#[strum(serialize_all = "snake_case")]
pub enum Operation {
pub enum RequestPayload {
/// Reads a file from the specified path on the remote machine
#[structopt(visible_aliases = &["cat"])]
FileRead {
@ -137,24 +156,44 @@ pub enum Operation {
ProcList {},
}
/// Represents an response to an operation performed on the remote machine
#[derive(Clone, Debug, PartialEq, Eq, AsRefStr, Serialize, Deserialize)]
#[serde(
rename_all = "snake_case",
deny_unknown_fields,
tag = "status",
content = "payload"
)]
#[strum(serialize_all = "snake_case")]
pub enum Response {
/// Represents a successfully-handled operation
Ok(ResponsePayload),
/// Represents an response to a request performed on the remote machine
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct Response {
/// A unique id associated with the response
pub id: usize,
/// Represents an operation that failed
Error {
/// The message associated with the failure
msg: String,
},
/// The id of the originating request, if there was one
/// (some responses are sent unprompted)
pub origin_id: Option<usize>,
/// The main payload containing the type and data of the response
pub payload: ResponsePayload,
}
impl Response {
/// Produces a new response with the given payload and origin id while supplying
/// randomly-generated id
pub fn from_payload_with_origin(payload: ResponsePayload, origin_id: usize) -> Self {
let id = rand::random();
Self {
id,
origin_id: Some(origin_id),
payload,
}
}
}
impl From<ResponsePayload> for Response {
/// Produces a new response with the given payload, no origin id, and a randomly-generated id
fn from(payload: ResponsePayload) -> Self {
let id = rand::random();
Self {
id,
origin_id: None,
payload,
}
}
}
/// Represents the payload of a successful response
@ -166,75 +205,36 @@ pub enum Response {
content = "data"
)]
pub enum ResponsePayload {
/// Response to reading a file
FileRead {
/// The path to the file on the remote machine
path: PathBuf,
/// General okay with no extra data, returned in cases like
/// creating or removing a directory, copying a file, or renaming
/// a file
Ok,
/// Contents of the file
data: Vec<u8>,
/// General-purpose failure that occurred from some request
Error {
/// Details about the error
description: String,
},
/// Response to writing a file
FileWrite {
/// The path to the file on the remote machine
path: PathBuf,
/// Total bytes written
bytes_written: usize,
/// Response containing some arbitrary, binary data
Blob {
/// Binary data associated with the response
data: Vec<u8>,
},
/// Response to appending to a file
FileAppend {
/// The path to the file on the remote machine
path: PathBuf,
/// Response when some data was written on the remote machine
/// such as a file write or append
Written {
/// Total bytes written
bytes_written: usize,
},
/// Response to reading a directory
DirRead {
/// The path to the directory on the remote machine
path: PathBuf,
/// Entries contained within directory
DirEntries {
/// Entries contained within the requested directory
entries: Vec<DirEntry>,
},
/// Response to creating a directory
DirCreate {
/// The path to the directory on the remote machine
path: PathBuf,
},
/// Response to removing a directory
DirRemove {
/// The path to the directory on the remote machine
path: PathBuf,
/// Total files & directories removed within the directory (0 if directory was empty)
total_removed: usize,
},
/// Response to copying a file/directory
Copy {
/// The path to the file/directory on the remote machine
src: PathBuf,
/// New location on the remote machine for copy of file/directory
dst: PathBuf,
},
/// Response to moving/renaming a file/directory
Rename {
/// The path to the file/directory on the remote machine
src: PathBuf,
/// New location on the remote machine for the file/directory
dst: PathBuf,
},
/// Response to starting a new process
ProcStart {
/// Arbitrary id associated with running process

@ -1,75 +1,102 @@
use crate::utils::Session;
use codec::{DistantCodec, DistantCodecError};
use derive_more::{Display, Error, From};
use futures::SinkExt;
use orion::{
aead::{self, SecretKey},
errors::UnknownCryptoError,
mod transport;
pub use transport::{Transport, TransportError};
use crate::{
data::{Request, Response, ResponsePayload},
utils::Session,
};
use log::*;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::{io, net::TcpStream};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
use tokio::{
io,
sync::{oneshot, watch},
};
use tokio_stream::wrappers::WatchStream;
mod codec;
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;
#[derive(Debug, Display, Error, From)]
pub enum TransportError {
CodecError(DistantCodecError),
EncryptError(UnknownCryptoError),
IoError(io::Error),
SerializeError(serde_cbor::Error),
}
/// Represents a client that can make requests against a server
pub struct Client {
/// Underlying transport used by client
transport: Arc<tokio::sync::Mutex<Transport>>,
/// Represents a transport of data across the network
pub struct Transport {
inner: Framed<TcpStream, DistantCodec>,
key: Arc<SecretKey>,
}
/// Collection of callbacks to be invoked upon receiving a response to a request
callbacks: Callbacks,
impl Transport {
/// Wraps a `TcpStream` and associated credentials in a transport layer
pub fn new(stream: TcpStream, key: Arc<SecretKey>) -> Self {
Self {
inner: Framed::new(stream, DistantCodec),
key,
}
}
/// Callback to trigger when a response is received without an origin or with an origin
/// not found in the list of callbacks
rx: watch::Receiver<Response>,
}
impl Client {
/// Establishes a connection using the provided session
pub async fn connect(session: Session) -> io::Result<Self> {
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
Ok(Self::new(stream, Arc::new(session.key)))
}
let transport = Arc::new(tokio::sync::Mutex::new(Transport::connect(session).await?));
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (tx, rx) = watch::channel(Response::from(ResponsePayload::Error {
description: String::from("Fake server response"),
}));
// Start a task that continually checks for responses and triggers callbacks
let transport_2 = Arc::clone(&transport);
let callbacks_2 = Arc::clone(&callbacks);
tokio::spawn(async move {
loop {
match transport_2.lock().await.receive::<Response>().await {
Ok(Some(res)) => {
let maybe_callback = res
.origin_id
.as_ref()
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
// If there is an origin to this response, trigger the callback
if let Some(tx) = maybe_callback {
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
/// Sends some data across the wire
pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
// Serialize, encrypt, and then (TODO) sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization
let data = serde_cbor::to_vec(&data)?;
let data = aead::seal(&self.key, &data)?;
// Otherwise, this goes into the junk draw of response handlers
} else {
if let Err(x) = tx.send(res) {
error!("Failed to trigger watch: {}", x);
}
}
}
Ok(None) => break,
Err(x) => {
error!("{}", x);
break;
}
}
}
});
self.inner
.send(&data)
.await
.map_err(TransportError::CodecError)
Ok(Self {
transport,
callbacks,
rx,
})
}
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, TransportError> {
// If data is received, we process like usual
if let Some(data) = self.inner.next().await {
// Validate (TODO) signature, decrypt, and then deserialize
let data = data?;
let data = aead::open(&self.key, &data)?;
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
/// Sends a request and waits for a response
pub async fn send(&self, req: Request) -> Result<Response, TransportError> {
// First, add a callback that will trigger when we get the response for this request
let (tx, rx) = oneshot::channel();
self.callbacks.lock().unwrap().insert(req.id, tx);
// Second, send the request
self.transport.lock().await.send(req).await?;
// Third, wait for the response
rx.await
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
}
// Otherwise, if no data is received, this means that our socket has closed
} else {
Ok(None)
}
/// Creates and returns a new stream of responses that are received with no originating request
pub fn to_response_stream(&self) -> WatchStream<Response> {
WatchStream::new(self.rx.clone())
}
}

@ -0,0 +1,75 @@
use crate::utils::Session;
use codec::{DistantCodec, DistantCodecError};
use derive_more::{Display, Error, From};
use futures::SinkExt;
use orion::{
aead::{self, SecretKey},
errors::UnknownCryptoError,
};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::{io, net::TcpStream};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
mod codec;
#[derive(Debug, Display, Error, From)]
pub enum TransportError {
CodecError(DistantCodecError),
EncryptError(UnknownCryptoError),
IoError(io::Error),
SerializeError(serde_cbor::Error),
}
/// Represents a transport of data across the network
pub struct Transport {
inner: Framed<TcpStream, DistantCodec>,
key: Arc<SecretKey>,
}
impl Transport {
/// Wraps a `TcpStream` and associated credentials in a transport layer
pub fn new(stream: TcpStream, key: Arc<SecretKey>) -> Self {
Self {
inner: Framed::new(stream, DistantCodec),
key,
}
}
/// Establishes a connection using the provided session
pub async fn connect(session: Session) -> io::Result<Self> {
let stream = TcpStream::connect(session.to_socket_addr().await?).await?;
Ok(Self::new(stream, Arc::new(session.key)))
}
/// Sends some data across the wire
pub async fn send<T: Serialize>(&mut self, data: T) -> Result<(), TransportError> {
// Serialize, encrypt, and then (TODO) sign
// NOTE: Cannot used packed implementation for now due to issues with deserialization
let data = serde_cbor::to_vec(&data)?;
let data = aead::seal(&self.key, &data)?;
self.inner
.send(&data)
.await
.map_err(TransportError::CodecError)
}
/// Receives some data from out on the wire, waiting until it's available,
/// returning none if the transport is now closed
pub async fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, TransportError> {
// If data is received, we process like usual
if let Some(data) = self.inner.next().await {
// Validate (TODO) signature, decrypt, and then deserialize
let data = data?;
let data = aead::open(&self.key, &data)?;
let data = serde_cbor::from_slice(&data)?;
Ok(Some(data))
// Otherwise, if no data is received, this means that our socket has closed
} else {
Ok(None)
}
}
}

@ -1,4 +1,4 @@
use crate::{subcommand, data::Operation};
use crate::{subcommand, data::RequestPayload};
use derive_more::{Display, Error, From};
use lazy_static::lazy_static;
use std::{
@ -116,7 +116,7 @@ pub struct ExecuteSubcommand {
pub format: ExecuteFormat,
#[structopt(subcommand)]
pub operation: Operation,
pub operation: RequestPayload,
}
/// Represents options for binding a server to an IP address

@ -1,7 +1,7 @@
use crate::{
data::Response,
net::{Transport, TransportError},
opt::{CommonOpt, ExecuteSubcommand},
data::{Request, Response},
net::{Client, TransportError},
opt::{CommonOpt, ExecuteFormat, ExecuteSubcommand},
utils::{Session, SessionError},
};
use derive_more::{Display, Error, From};
@ -22,16 +22,20 @@ pub fn run(cmd: ExecuteSubcommand, opt: CommonOpt) -> Result<(), Error> {
async fn run_async(cmd: ExecuteSubcommand, _opt: CommonOpt) -> Result<(), Error> {
let session = Session::load().await?;
let mut transport = Transport::connect(session).await?;
let client = Client::connect(session).await?;
// Send our operation
transport.send(cmd.operation).await?;
let req = Request::from(cmd.operation);
// Continue to receive and process responses as long as we get them or we decide to end
loop {
let response = transport.receive::<Response>().await?;
println!("RESPONSE: {:?}", response);
}
let res = client.send(req).await?;
let res_string = match cmd.format {
ExecuteFormat::Json => serde_json::to_string(&res)
.map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
ExecuteFormat::Shell => format!("{:?}", res),
};
println!("{}", res_string);
// TODO: Process result to determine if we want to create a watch stream and continue
// to examine results
Ok(())
}

@ -1,5 +1,5 @@
use crate::{
data::{Operation, Response},
data::{Request, Response, ResponsePayload},
net::Transport,
opt::{CommonOpt, ConvertToIpAddrError, ListenSubcommand},
};
@ -90,17 +90,20 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
// Spawn a new task that loops to handle requests from the client
tokio::spawn(async move {
loop {
match transport.receive::<Operation>().await {
match transport.receive::<Request>().await {
Ok(Some(request)) => {
trace!(
"<Client @ {}> Received request of type {}",
addr_string.as_str(),
request.as_ref()
request.payload.as_ref()
);
let response = Response::Error {
msg: String::from("Unimplemented"),
};
let response = Response::from_payload_with_origin(
ResponsePayload::Error {
description: String::from("Unimplemented"),
},
request.id,
);
if let Err(x) = transport.send(response).await {
error!("<Client @ {}> {}", addr_string.as_str(), x);

Loading…
Cancel
Save