|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
use crate::{
|
|
|
|
|
utils::Timer, ConnectionId, GenericServerRef, Interest, Listener, Request, Response, Server,
|
|
|
|
|
ServerConnection, ServerCtx, ServerRef, ServerReply, ServerState, Shutdown, TypedTransport,
|
|
|
|
|
utils::Timer, ConnectionId, FramedTransport, GenericServerRef, Interest, Listener, Request,
|
|
|
|
|
Response, Server, ServerConnection, ServerCtx, ServerRef, ServerReply, ServerState, Shutdown,
|
|
|
|
|
Transport, UntypedRequest,
|
|
|
|
|
};
|
|
|
|
|
use log::*;
|
|
|
|
|
use serde::{de::DeserializeOwned, Serialize};
|
|
|
|
@ -39,7 +40,7 @@ pub trait ServerExt {
|
|
|
|
|
fn start<L, T>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
where
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport<Input = Self::Request, Output = Self::Response> + Send + 'static;
|
|
|
|
|
T: Transport + Send + 'static;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<S> ServerExt for S
|
|
|
|
@ -55,7 +56,7 @@ where
|
|
|
|
|
fn start<L, T>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
|
|
|
|
|
where
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport<Input = Self::Request, Output = Self::Response> + Send + 'static,
|
|
|
|
|
T: Transport + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
let server = Arc::new(self);
|
|
|
|
|
let state = Arc::new(ServerState::new());
|
|
|
|
@ -68,12 +69,12 @@ where
|
|
|
|
|
|
|
|
|
|
async fn task<S, L, T>(server: Arc<S>, state: Arc<ServerState>, mut listener: L)
|
|
|
|
|
where
|
|
|
|
|
S: Server<Request = T::Input, Response = T::Output> + Sync + 'static,
|
|
|
|
|
S: Server + Sync + 'static,
|
|
|
|
|
S::Request: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
S::Response: Serialize + Send + 'static,
|
|
|
|
|
S::LocalData: Default + Send + Sync + 'static,
|
|
|
|
|
L: Listener<Output = T> + 'static,
|
|
|
|
|
T: TypedTransport + Send + 'static,
|
|
|
|
|
T::Input: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
T::Output: Serialize + Send + 'static,
|
|
|
|
|
T: Transport + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
// Grab a copy of our server's configuration so we can leverage it below
|
|
|
|
|
let config = server.config();
|
|
|
|
@ -183,17 +184,17 @@ struct ConnectionTask<S, T, D> {
|
|
|
|
|
server: Arc<S>,
|
|
|
|
|
state: Weak<ServerState>,
|
|
|
|
|
transport: T,
|
|
|
|
|
local_data: D,
|
|
|
|
|
local_data: Arc<D>,
|
|
|
|
|
shutdown_timer: Weak<Mutex<Timer<()>>>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<S, T, D> ConnectionTask<S, T, D>
|
|
|
|
|
where
|
|
|
|
|
S: Server<Request = T::Input, Response = T::Output, LocalData = D> + Sync + 'static,
|
|
|
|
|
S: Server<LocalData = D> + Sync + 'static,
|
|
|
|
|
S::Request: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
S::Response: Serialize + Send + 'static,
|
|
|
|
|
D: Default + Send + Sync + 'static,
|
|
|
|
|
T: TypedTransport + Send + 'static,
|
|
|
|
|
T::Input: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
|
T::Output: Serialize + Send + 'static,
|
|
|
|
|
T: Transport + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
pub fn spawn(self) -> JoinHandle<()> {
|
|
|
|
|
tokio::spawn(self.run())
|
|
|
|
@ -203,7 +204,11 @@ where
|
|
|
|
|
let connection_id = self.id;
|
|
|
|
|
|
|
|
|
|
// Construct a queue of outgoing responses
|
|
|
|
|
let (tx, mut rx) = mpsc::channel::<Response<T::Output>>(1);
|
|
|
|
|
let (tx, mut rx) = mpsc::channel::<Response<S::Response>>(1);
|
|
|
|
|
|
|
|
|
|
// TODO: We should perform a handshake here to determine which codec(s) to use in
|
|
|
|
|
// collaboration with the client
|
|
|
|
|
let mut transport = FramedTransport::new(self.transport);
|
|
|
|
|
|
|
|
|
|
loop {
|
|
|
|
|
let ready = self
|
|
|
|
@ -213,31 +218,48 @@ where
|
|
|
|
|
.expect("[Conn {connection_id}] Failed to examine ready state");
|
|
|
|
|
|
|
|
|
|
if ready.is_readable() {
|
|
|
|
|
match self.transport.try_read() {
|
|
|
|
|
Ok(Some(request)) => {
|
|
|
|
|
let reply = ServerReply {
|
|
|
|
|
origin_id: request.id.clone(),
|
|
|
|
|
tx: tx.clone(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let ctx = ServerCtx {
|
|
|
|
|
connection_id,
|
|
|
|
|
request,
|
|
|
|
|
reply: reply.clone(),
|
|
|
|
|
local_data: Arc::clone(&self.local_data),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
self.server.on_request(ctx).await;
|
|
|
|
|
}
|
|
|
|
|
match transport.try_read_frame() {
|
|
|
|
|
Ok(Some(frame)) => match UntypedRequest::from_slice(frame.as_item()) {
|
|
|
|
|
Ok(request) => match request.to_typed_request() {
|
|
|
|
|
Ok(request) => {
|
|
|
|
|
let reply = ServerReply {
|
|
|
|
|
origin_id: request.id.clone(),
|
|
|
|
|
tx: tx.clone(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let ctx = ServerCtx {
|
|
|
|
|
connection_id,
|
|
|
|
|
request,
|
|
|
|
|
reply: reply.clone(),
|
|
|
|
|
local_data: Arc::clone(&self.local_data),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
self.server.on_request(ctx).await;
|
|
|
|
|
}
|
|
|
|
|
Err(x) => {
|
|
|
|
|
if log::log_enabled!(Level::Trace) {
|
|
|
|
|
trace!(
|
|
|
|
|
"[Conn {connection_id}] Request payload: {}",
|
|
|
|
|
String::from_utf8_lossy(&request.payload),
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
error!("[Conn {connection_id}] Invalid request: {x}");
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
Err(x) => {
|
|
|
|
|
error!("[Conn {connection_id}] Invalid request: {x}");
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
Ok(None) => {
|
|
|
|
|
debug!("[Conn {connection_id}] Connection closed");
|
|
|
|
|
|
|
|
|
|
// Remove the connection from our state if it has closed
|
|
|
|
|
if let Some(state) = Weak::upgrade(&self.weak_state) {
|
|
|
|
|
state.connections.write().await.remove(&self.connection_id);
|
|
|
|
|
if let Some(state) = Weak::upgrade(&self.state) {
|
|
|
|
|
state.connections.write().await.remove(&self.id);
|
|
|
|
|
|
|
|
|
|
// If we have no more connections, start the timer
|
|
|
|
|
if let Some(timer) = Weak::upgrade(&self.weak_shutdown_timer) {
|
|
|
|
|
if let Some(timer) = Weak::upgrade(&self.shutdown_timer) {
|
|
|
|
|
if state.connections.read().await.is_empty() {
|
|
|
|
|
timer.lock().await.start();
|
|
|
|
|
}
|
|
|
|
@ -290,10 +312,7 @@ where
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod tests {
|
|
|
|
|
use super::*;
|
|
|
|
|
use crate::{
|
|
|
|
|
InmemoryTypedTransport, IntoSplit, MpscListener, MpscTransportReadHalf,
|
|
|
|
|
MpscTransportWriteHalf, ServerConfig,
|
|
|
|
|
};
|
|
|
|
|
use crate::{InmemoryTransport, MpscListener, ServerConfig};
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use std::time::Duration;
|
|
|
|
|
|
|
|
|
|