Unfinished

pull/146/head
Chip Senkbeil 2 years ago
parent 569661cd04
commit ad9f1ac05a
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,7 +1,5 @@
use crate::{
Codec, FramedTransport, IntoSplit, Transport, TransportRead, TransportWrite, Request,
Response, TypedAsyncRead, TypedAsyncWrite,
};
use crate::{FramedTransport, Interest, Request, Transport, UntypedResponse};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{
ops::{Deref, DerefMut},
@ -28,11 +26,8 @@ where
/// Used to send requests to a server
channel: Channel<T, U>,
/// Contains the task that is running to send requests to a server
request_task: JoinHandle<()>,
/// Contains the task that is running to receive responses from a server
response_task: JoinHandle<()>,
/// Contains the task that is running to send requests and receive responses from a server
task: JoinHandle<()>,
}
impl<T, U> Client<T, U>
@ -40,40 +35,83 @@ where
T: Send + Sync + Serialize,
U: Send + Sync + DeserializeOwned,
{
/// Initializes a client using the provided reader and writer
pub fn new<R, W>(mut writer: W, mut reader: R) -> io::Result<Self>
/// Initializes a client using the provided transport
pub fn new<V>(transport: V) -> io::Result<Self>
where
R: TypedAsyncRead<Response<U>> + Send + 'static,
W: TypedAsyncWrite<Request<T>> + Send + 'static,
V: Transport,
{
let post_office = Arc::new(PostOffice::default());
let weak_post_office = Arc::downgrade(&post_office);
let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
// Do handshake with the server
// TODO: Support user configuration
let mut transport: FramedTransport<_, _> = todo!();
// Start a task that continually checks for responses and delivers them using the
// post office
let response_task = tokio::spawn(async move {
let task = tokio::spawn(async move {
loop {
match reader.read().await {
Ok(Some(res)) => {
// Try to send response to appropriate mailbox
// TODO: How should we handle false response? Did logging in past
post_office.deliver_response(res).await;
}
Ok(None) => {
break;
}
Err(_) => {
break;
let ready = transport
.ready(Interest::READABLE | Interest::WRITABLE)
.await
.expect("Failed to examine ready state");
if ready.is_readable() {
match transport.try_read_frame() {
Ok(Some(frame)) => match UntypedResponse::from_slice(frame.as_item()) {
Ok(response) => {
match response.to_typed_response() {
Ok(response) => {
// Try to send response to appropriate mailbox
// TODO: How should we handle false response? Did logging in past
post_office.deliver_response(response).await;
}
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"Failed receiving {}",
String::from_utf8_lossy(&response.payload),
);
}
error!("Invalid response: {x}");
}
}
}
Err(x) => {
error!("Invalid response: {x}");
}
},
Ok(None) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => (),
Err(x) => {
error!("Failed to read next frame: {x}");
}
}
}
}
});
let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
let request_task = tokio::spawn(async move {
while let Some(req) = rx.recv().await {
if writer.write(req).await.is_err() {
break;
if ready.is_writable() {
if let Ok(request) = rx.try_recv() {
match request.to_vec() {
Ok(data) => match transport.try_write_frame(data) {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => (),
Err(x) => error!("Send failed: {x}"),
},
Err(x) => {
error!("Unable to serialize outgoing request: {x}");
}
}
}
match transport.try_flush() {
Ok(()) => (),
Err(x) if x.kind() == io::ErrorKind::WouldBlock => (),
Err(x) => {
error!("Failed to flush outgoing data: {x}");
}
}
}
}
});
@ -83,11 +121,7 @@ where
post_office: weak_post_office,
};
Ok(Self {
channel,
request_task,
response_task,
})
Ok(Self { channel, task })
}
/// Convert into underlying channel
@ -103,18 +137,17 @@ where
/// Waits for the client to terminate, which results when the receiving end of the network
/// connection is closed (or the client is shutdown)
pub async fn wait(self) -> Result<(), JoinError> {
tokio::try_join!(self.request_task, self.response_task).map(|_| ())
self.task.await
}
/// Abort the client's current connection by forcing its tasks to abort
pub fn abort(&self) {
self.request_task.abort();
self.response_task.abort();
self.task.abort();
}
/// Returns true if client's underlying event processing has finished/terminated
pub fn is_finished(&self) -> bool {
self.request_task.is_finished() && self.response_task.is_finished()
self.task.is_finished()
}
}

@ -142,14 +142,14 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, FramedTransport, TypedAsyncRead, TypedAsyncWrite};
use crate::{Client, FramedTransport};
use std::time::Duration;
type TestClient = Client<u8, u8>;
#[tokio::test]
async fn mail_should_return_mailbox_that_receives_responses_until_transport_closes() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let (t1, mut t2) = FramedTransport::pair(100);
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
@ -184,7 +184,7 @@ mod tests {
#[tokio::test]
async fn send_should_wait_until_response_received() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let (t1, mut t2) = FramedTransport::pair(100);
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
@ -200,7 +200,7 @@ mod tests {
#[tokio::test]
async fn send_timeout_should_fail_if_response_not_received_in_time() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let (t1, mut t2) = FramedTransport::pair(100);
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
@ -210,15 +210,13 @@ mod tests {
x => panic!("Unexpected response: {:?}", x),
}
let _req = TypedAsyncRead::<Request<u8>>::read(&mut t2)
.await
.unwrap()
.unwrap();
let frame = t2.try_read_frame().unwrap().unwrap();
let _req: Request<u8> = Request::from_slice(&frame.as_item()).unwrap();
}
#[tokio::test]
async fn fire_should_send_request_and_not_wait_for_response() {
let (t1, mut t2) = FramedTransport::make_test_pair();
let (t1, mut t2) = FramedTransport::pair(100);
let session: TestClient = Client::from_framed_transport(t1).unwrap();
let mut channel = session.clone_channel();
@ -228,9 +226,7 @@ mod tests {
x => panic!("Unexpected response: {:?}", x),
}
let _req = TypedAsyncRead::<Request<u8>>::read(&mut t2)
.await
.unwrap()
.unwrap();
let frame = t2.try_read_frame().unwrap().unwrap();
let _req: Request<u8> = Request::from_slice(&frame.as_item()).unwrap();
}
}

@ -1,4 +1,4 @@
use crate::{Client, Codec, FramedTransport, IntoSplit, UnixSocketTransport};
use crate::{Client, Codec, FramedTransport, UnixSocketTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, path::Path};

@ -1,4 +1,4 @@
use crate::{Client, Codec, FramedTransport, IntoSplit, WindowsPipeTransport};
use crate::{Client, Codec, FramedTransport, WindowsPipeTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{

@ -0,0 +1,55 @@
use crate::{BoxedCodec, FramedTransport, PlainCodec, Request, Response, Transport};
use serde::{Deserialize, Serialize};
use std::io;
/// Represents options that the server has available for a connection
#[derive(Serialize, Deserialize)]
struct ServerConnectionOptions {
/// Choices for encryption as string labels
pub encryption: Vec<String>,
/// Choices for compression as string labels
pub compression: Vec<String>,
}
/// Represents the choice that the client has made regarding server connection options
struct ClientConnectionChoice {
/// Selected encryption
pub encryption: String,
/// Selected compression
pub compression: String,
}
/// Performs the client-side of a handshake
pub async fn client_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
where
T: Transport,
{
let transport = FramedTransport::new(transport, PlainCodec::new());
// Wait for the server to send us choices for communication
let frame = transport.read_frame().await?.ok_or_else(|| {
io::Error::new(
io::ErrorKind::ConnectionAborted,
"Connection aborted before receiving server communication",
)
})?;
// Parse the frame as the request for the client
let request = Request::<ServerConnectionOptions>::from_slice(frame.as_item())?;
// Select an encryption and compression choice
let encryption = request.payload.encryption[0];
let compression = request.payload.compression[0];
// Respond back with choices
}
/// Performs the server-side of a handshake
pub async fn server_handshake<T>(transport: T) -> io::Result<FramedTransport<T, BoxedCodec>>
where
T: Transport,
{
let transport = FramedTransport::new(transport, PlainCodec::new());
}

@ -1,6 +1,7 @@
mod any;
/* mod auth;
mod client; */
// mod auth;
mod client;
mod handshake;
mod id;
mod key;
mod listener;
@ -11,8 +12,8 @@ mod transport;
mod utils;
pub use any::*;
/* pub use auth::*;
pub use client::*; */
// pub use auth::*;
pub use client::*;
pub use id::*;
pub use key::*;
pub use listener::*;

@ -1,5 +1,6 @@
use super::{parse_msg_pack_str, Id};
use crate::utils;
use derive_more::{Display, Error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{borrow::Cow, io};
@ -61,7 +62,7 @@ impl<T: schemars::JsonSchema> Response<T> {
}
/// Error encountered when attempting to parse bytes as an untyped response
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, Display, Error, PartialEq, Eq, Hash)]
pub enum UntypedResponseParseError {
/// When the bytes do not represent a response
WrongType,
@ -88,7 +89,7 @@ pub struct UntypedResponse<'a> {
impl<'a> UntypedResponse<'a> {
/// Attempts to convert an untyped request to a typed request
pub fn to_typed_request<T: DeserializeOwned>(&self) -> io::Result<Response<T>> {
pub fn to_typed_response<T: DeserializeOwned>(&self) -> io::Result<Response<T>> {
Ok(Response {
id: self.id.to_string(),
origin_id: self.origin_id.to_string(),

@ -228,35 +228,33 @@ where
if ready.is_readable() {
match transport.try_read_frame() {
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
Ok(request) => {
if log::log_enabled!(Level::Trace) {
trace!(
"[Conn {connection_id}] Receiving {}",
String::from_utf8_lossy(&request.payload),
);
Ok(request) => match request.to_typed_request() {
Ok(request) => {
let reply = ServerReply {
origin_id: request.id.clone(),
tx: tx.clone(),
};
let ctx = ServerCtx {
connection_id,
request,
reply: reply.clone(),
local_data: Arc::clone(&self.local_data),
};
self.server.on_request(ctx).await;
}
match request.to_typed_request() {
Ok(request) => {
let reply = ServerReply {
origin_id: request.id.clone(),
tx: tx.clone(),
};
let ctx = ServerCtx {
connection_id,
request,
reply: reply.clone(),
local_data: Arc::clone(&self.local_data),
};
self.server.on_request(ctx).await;
}
Err(x) => {
error!("[Conn {connection_id}] Invalid request: {x}");
Err(x) => {
if log::log_enabled!(Level::Trace) {
trace!(
"[Conn {connection_id}] Failed receiving {}",
String::from_utf8_lossy(&request.payload),
);
}
error!("[Conn {connection_id}] Invalid request: {x}");
}
}
},
Err(x) => {
error!("[Conn {connection_id}] Invalid request: {x}");
}
@ -330,7 +328,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{InmemoryTransport, MpscListener, ServerConfig};
use crate::{InmemoryTransport, MpscListener, Request, ServerConfig};
use async_trait::async_trait;
use std::time::Duration;
@ -356,14 +354,8 @@ mod tests {
fn make_listener(
buffer: usize,
) -> (
mpsc::Sender<(
MpscTransportWriteHalf<Response<String>>,
MpscTransportReadHalf<Request<u16>>,
)>,
MpscListener<(
MpscTransportWriteHalf<Response<String>>,
MpscTransportReadHalf<Request<u16>>,
)>,
mpsc::Sender<InmemoryTransport>,
MpscListener<InmemoryTransport>,
) {
MpscListener::channel(buffer)
}
@ -374,9 +366,8 @@ mod tests {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (mut transport, connection) =
InmemoryTypedTransport::<Request<u16>, Response<String>>::pair(100);
tx.send(connection.into_split())
let (mut transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
@ -384,11 +375,12 @@ mod tests {
.expect("Failed to start server");
transport
.write(Request::new(123))
.await
.try_write(&Request::new(123).to_vec().unwrap())
.expect("Failed to send request");
let response: Response<String> = transport.read().await.unwrap().unwrap();
let mut buf = [0u8; 1024];
let n = transport.try_read(&mut buf).unwrap();
let response: Response<String> = Response::from_slice(&buf[..n]).unwrap();
assert_eq!(response.payload, "hello");
}
@ -417,9 +409,8 @@ mod tests {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) =
InmemoryTypedTransport::<Request<u16>, Response<String>>::pair(100);
tx.send(connection.into_split())
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
@ -446,9 +437,8 @@ mod tests {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) =
InmemoryTypedTransport::<Request<u16>, Response<String>>::pair(100);
tx.send(connection.into_split())
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
@ -471,9 +461,8 @@ mod tests {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) =
InmemoryTypedTransport::<Request<u16>, Response<String>>::pair(100);
tx.send(connection.into_split())
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");

@ -155,6 +155,22 @@ where
}
}
/// Continues to invoke [`try_read_frame`] until a frame is successfully read, an error is
/// encountered that is not [`ErrorKind::WouldBlock`], or the underlying transport has closed.
///
/// [`try_read_frame`]: FramedTransport::try_read_frame
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub async fn read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
loop {
self.readable().await?;
match self.try_read_frame() {
Err(x) if x.kind() == io::ErrorKind::WouldBlock => continue,
x => return x,
}
}
}
/// Writes a `frame` of bytes by using the [`Codec`] tied to this transport.
///
/// This is accomplished by continually calling the inner transport's `try_write`. If 0 is
@ -173,6 +189,31 @@ where
// Attempt to write everything in our queue
self.try_flush()
}
/// Invokes [`try_write_frame`] followed by a continuous calls to [`try_flush`] until a frame
/// is successfully written, an error is encountered that is not [`ErrorKind::WouldBlock`], or
/// the underlying transport has closed.
///
/// [`try_write_frame`]: FramedTransport::try_write_frame
/// [`try_flush`]: FramedTransport::try_flush
/// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
pub async fn write_frame<'a>(&mut self, frame: impl Into<Frame<'a>>) -> io::Result<()> {
self.writeable().await?;
match self.try_write_frame(frame) {
// Would block, so continually try to flush until good to go
Err(x) if x.kind() == io::ErrorKind::WouldBlock => loop {
self.writeable().await?;
match self.try_flush() {
Err(x) if x.kind() == io::ErrorKind::WouldBlock => continue,
x => return x,
}
},
// Already fully succeeded or failed
x => x,
}
}
}
#[async_trait]

Loading…
Cancel
Save