Refactor bulk of distant-net code into a common module such that we have three top-level modules: common, client, and server

pull/146/head
Chip Senkbeil 2 years ago
parent 52b4a92ee8
commit 436e73b903
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,4 +1,6 @@
use crate::{FramedTransport, Interest, Reconnectable, Request, Transport, UntypedResponse};
use crate::common::{
FramedTransport, Interest, Reconnectable, Request, Transport, UntypedResponse,
};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};

@ -13,9 +13,10 @@ mod windows;
#[cfg(windows)]
pub use windows::*;
use crate::{
use crate::client::Client;
use crate::common::{
auth::{AuthHandler, Authenticate},
Client, FramedTransport, Transport,
FramedTransport, Transport,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, future::Future, io, time::Duration};

@ -1,4 +1,5 @@
use crate::{auth::AuthHandler, Client, ClientBuilder, TcpTransport};
use crate::client::{Client, ClientBuilder};
use crate::common::{auth::AuthHandler, TcpTransport};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{io, net::ToSocketAddrs, time::Duration};

@ -1,4 +1,5 @@
use crate::{auth::AuthHandler, Client, ClientBuilder, UnixSocketTransport};
use crate::client::{Client, ClientBuilder};
use crate::common::{auth::AuthHandler, UnixSocketTransport};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path;
use tokio::{io, time::Duration};

@ -1,12 +1,7 @@
use crate::{
auth::{AuthHandler, Authenticate},
Client, ClientBuilder, FramedTransport, WindowsPipeTransport,
};
use crate::client::{Client, ClientBuilder};
use crate::common::{auth::AuthHandler, WindowsPipeTransport};
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert,
ffi::{OsStr, OsString},
};
use std::ffi::{OsStr, OsString};
use tokio::{io, time::Duration};
/// Builder for a client that will connect over a Windows pipe
@ -35,7 +30,7 @@ impl<T> WindowsPipeClientBuilder<T> {
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
inner: self.inner.timeout(timeout),
timeout: timeout.into(),
local: self.local,
}
}
}

@ -1,4 +1,4 @@
use crate::{Request, Response};
use crate::common::{Request, Response};
use std::{convert, io, sync::Weak};
use tokio::{sync::mpsc, time::Duration};
@ -134,7 +134,8 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, FramedTransport, InmemoryTransport};
use crate::client::Client;
use crate::common::{FramedTransport, InmemoryTransport};
use std::time::Duration;
use test_log::test;

@ -1,4 +1,4 @@
use crate::{Id, Response};
use crate::common::{Id, Response};
use std::{
collections::HashMap,
sync::{Arc, Weak},

@ -0,0 +1,15 @@
mod any;
pub mod auth;
mod connection;
mod listener;
mod packet;
mod port;
mod transport;
pub(crate) mod utils;
pub use any::*;
pub use connection::*;
pub use listener::*;
pub use packet::*;
pub use port::*;
pub use transport::*;

@ -1,5 +1,5 @@
use super::{msg::*, AuthHandler};
use crate::{utils, FramedTransport, Transport};
use crate::common::{utils, FramedTransport, Transport};
use async_trait::async_trait;
use log::*;
use std::io;

@ -1,4 +1,4 @@
use crate::HeapSecretKey;
use crate::common::HeapSecretKey;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

@ -99,7 +99,7 @@ pub trait AuthenticationMethod: Send + Sync {
#[cfg(test)]
mod tests {
use super::*;
use crate::FramedTransport;
use crate::common::FramedTransport;
use test_log::test;
struct SuccessAuthenticationMethod;

@ -1,5 +1,5 @@
use super::{AuthenticationMethod, Authenticator, Challenge, Error, Question};
use crate::HeapSecretKey;
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use std::io;
@ -50,7 +50,7 @@ impl AuthenticationMethod for StaticKeyAuthenticationMethod {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
use crate::common::{
auth::msg::{AuthenticationResponse, ChallengeResponse},
FramedTransport,
};

@ -231,7 +231,9 @@ where
/// performing the following:
///
/// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 2. Authenticates the established connection to ensure it is valid
/// 2. Authenticates the established connection to ensure it is valid by either using the
/// given `verifier` or, if working with an existing client connection, will validate an OTP
/// from our database
/// 3. Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn server(transport: T, verifier: &Verifier, keychain: Keychain) -> io::Result<Self> {

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use std::io;

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use derive_more::From;
use std::io;

@ -1,4 +1,4 @@
use crate::Listener;
use super::Listener;
use async_trait::async_trait;
use derive_more::From;
use std::io;

@ -1,4 +1,5 @@
use crate::{Listener, PortRange, TcpTransport};
use super::Listener;
use crate::common::{PortRange, TcpTransport};
use async_trait::async_trait;
use std::{fmt, io, net::IpAddr};
use tokio::net::TcpListener as TokioTcpListener;
@ -64,7 +65,7 @@ impl Listener for TcpListener {
#[cfg(test)]
mod tests {
use super::*;
use crate::Transport;
use crate::common::Transport;
use std::net::{Ipv6Addr, SocketAddr};
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};

@ -1,4 +1,5 @@
use crate::{Listener, UnixSocketTransport};
use super::Listener;
use crate::common::UnixSocketTransport;
use async_trait::async_trait;
use std::{
fmt, io,
@ -94,7 +95,7 @@ impl Listener for UnixSocketListener {
#[cfg(test)]
mod tests {
use super::*;
use crate::Transport;
use crate::common::Transport;
use tempfile::NamedTempFile;
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};

@ -1,4 +1,5 @@
use crate::{Listener, NamedPipe, WindowsPipeTransport};
use super::Listener;
use crate::common::{NamedPipe, WindowsPipeTransport};
use async_trait::async_trait;
use std::{
ffi::{OsStr, OsString},
@ -66,7 +67,7 @@ impl Listener for WindowsPipeListener {
#[cfg(test)]
mod tests {
use super::*;
use crate::Transport;
use crate::common::Transport;
use test_log::test;
use tokio::{sync::oneshot, task::JoinHandle};

@ -1,5 +1,5 @@
use super::{parse_msg_pack_str, write_str_msg_pack, Id};
use crate::utils;
use crate::common::utils;
use derive_more::{Display, Error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{borrow::Cow, io, str};

@ -1,5 +1,5 @@
use super::{parse_msg_pack_str, write_str_msg_pack, Id};
use crate::utils;
use crate::common::utils;
use derive_more::{Display, Error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{borrow::Cow, io};

@ -1,5 +1,5 @@
use super::{InmemoryTransport, Interest, Ready, Reconnectable, Transport};
use crate::utils;
use crate::common::utils;
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
use log::*;
@ -802,7 +802,7 @@ impl FramedTransport<InmemoryTransport> {
#[cfg(test)]
mod tests {
use super::*;
use crate::TestTransport;
use crate::common::TestTransport;
use bytes::BufMut;
use test_log::test;

@ -1,22 +1,9 @@
mod any;
pub mod auth;
mod client;
mod connection;
mod listener;
mod packet;
mod port;
pub mod common;
mod server;
mod transport;
mod utils;
pub use any::*;
pub use client::*;
pub use connection::*;
pub use listener::*;
pub use packet::*;
pub use port::*;
pub use server::*;
pub use transport::*;
pub use log;
pub use paste;

@ -1,4 +1,4 @@
use crate::{auth::Verifier, Listener, Transport};
use crate::common::{auth::Verifier, Listener, Transport};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -12,7 +12,6 @@ mod config;
pub use config::*;
mod connection;
pub use connection::ConnectionId;
use connection::*;
mod context;
@ -198,7 +197,7 @@ where
// Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop();
let connection = Connection::build()
let connection = ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.transport(transport)
@ -219,9 +218,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{
use crate::common::{
auth::{Authenticate, AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod},
FramedTransport, InmemoryTransport, MpscListener, Request, Response, ServerConfig,
FramedTransport, InmemoryTransport, MpscListener, Request, Response,
};
use async_trait::async_trait;
use std::time::Duration;

@ -1,6 +1,5 @@
use crate::{
auth::Verifier, PortRange, Server, ServerConfig, ServerHandler, TcpListener, TcpServerRef,
};
use crate::common::{auth::Verifier, PortRange, TcpListener};
use crate::server::{Server, ServerConfig, ServerHandler, TcpServerRef};
use serde::{de::DeserializeOwned, Serialize};
use std::{io, net::IpAddr};
@ -54,7 +53,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{auth::DummyAuthHandler, Client, Request, ServerCtx};
use crate::client::Client;
use crate::common::{auth::DummyAuthHandler, Request};
use crate::server::ServerCtx;
use async_trait::async_trait;
use std::net::{Ipv6Addr, SocketAddr};
use test_log::test;

@ -1,6 +1,5 @@
use crate::{
auth::Verifier, Server, ServerConfig, ServerHandler, UnixSocketListener, UnixSocketServerRef,
};
use crate::common::{auth::Verifier, UnixSocketListener};
use crate::server::{Server, ServerConfig, ServerHandler, UnixSocketServerRef};
use serde::{de::DeserializeOwned, Serialize};
use std::{io, path::Path};
@ -55,7 +54,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{auth::DummyAuthHandler, Client, Request, ServerCtx};
use crate::client::Client;
use crate::common::{auth::DummyAuthHandler, Request};
use crate::server::ServerCtx;
use async_trait::async_trait;
use tempfile::NamedTempFile;
use test_log::test;

@ -1,6 +1,5 @@
use crate::{
auth::Verifier, Server, ServerConfig, ServerHandler, WindowsPipeListener, WindowsPipeServerRef,
};
use crate::common::{auth::Verifier, WindowsPipeListener};
use crate::server::{Server, ServerConfig, ServerHandler, WindowsPipeServerRef};
use serde::{de::DeserializeOwned, Serialize};
use std::{
ffi::{OsStr, OsString},
@ -70,10 +69,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{
auth::{Client, DummyAuthHandler, Request, ServerCtx},
Client, ConnectionCtx, Request, ServerCtx,
};
use crate::client::Client;
use crate::common::{auth::DummyAuthHandler, Request};
use crate::server::ServerCtx;
use async_trait::async_trait;
use test_log::test;

@ -1,7 +1,6 @@
use super::{ServerState, ShutdownTimer};
use crate::{
auth::Verifier, ConnectionCtx, Interest, Response, ServerCtx, ServerHandler, ServerReply,
Transport, UntypedRequest,
use super::{ConnectionCtx, ServerCtx, ServerHandler, ServerReply, ServerState, ShutdownTimer};
use crate::common::{
auth::Verifier, Connection, ConnectionId, Interest, Response, Transport, UntypedRequest,
};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -19,7 +18,7 @@ use tokio::{
const SLEEP_DURATION: Duration = Duration::from_millis(50);
/// Represents an individual connection on the server
pub struct Connection {
pub struct ConnectionTask {
/// Unique identifier tied to the connection
id: ConnectionId,
@ -27,11 +26,11 @@ pub struct Connection {
task: JoinHandle<()>,
}
impl Connection {
impl ConnectionTask {
/// Starts building a new connection
pub fn build() -> ConnectionBuilder<(), ()> {
pub fn build() -> ConnectionTaskBuilder<(), ()> {
let id: ConnectionId = rand::random();
ConnectionBuilder {
ConnectionTaskBuilder {
id,
handler: Weak::new(),
state: Weak::new(),
@ -53,7 +52,7 @@ impl Connection {
}
}
pub struct ConnectionBuilder<H, T> {
pub struct ConnectionTaskBuilder<H, T> {
id: ConnectionId,
handler: Weak<H>,
state: Weak<ServerState>,
@ -63,9 +62,9 @@ pub struct ConnectionBuilder<H, T> {
verifier: Weak<Verifier>,
}
impl<H, T> ConnectionBuilder<H, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionBuilder<U, T> {
ConnectionBuilder {
impl<H, T> ConnectionTaskBuilder<H, T> {
pub fn handler<U>(self, handler: Weak<U>) -> ConnectionTaskBuilder<U, T> {
ConnectionTaskBuilder {
id: self.id,
handler,
state: self.state,
@ -76,8 +75,8 @@ impl<H, T> ConnectionBuilder<H, T> {
}
}
pub fn state(self, state: Weak<ServerState>) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
pub fn state(self, state: Weak<ServerState>) -> ConnectionTaskBuilder<H, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state,
@ -88,8 +87,8 @@ impl<H, T> ConnectionBuilder<H, T> {
}
}
pub fn transport<U>(self, transport: U) -> ConnectionBuilder<H, U> {
ConnectionBuilder {
pub fn transport<U>(self, transport: U) -> ConnectionTaskBuilder<H, U> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
@ -103,8 +102,8 @@ impl<H, T> ConnectionBuilder<H, T> {
pub(crate) fn shutdown_timer(
self,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
) -> ConnectionTaskBuilder<H, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
@ -115,8 +114,8 @@ impl<H, T> ConnectionBuilder<H, T> {
}
}
pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
pub fn sleep_duration(self, sleep_duration: Duration) -> ConnectionTaskBuilder<H, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
@ -127,8 +126,8 @@ impl<H, T> ConnectionBuilder<H, T> {
}
}
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionTaskBuilder<H, T> {
ConnectionTaskBuilder {
id: self.id,
handler: self.handler,
state: self.state,
@ -140,7 +139,7 @@ impl<H, T> ConnectionBuilder<H, T> {
}
}
impl<H, T> ConnectionBuilder<H, T>
impl<H, T> ConnectionTaskBuilder<H, T>
where
H: ServerHandler + Sync + 'static,
H::Request: DeserializeOwned + Send + Sync + 'static,
@ -148,17 +147,17 @@ where
H::LocalData: Default + Send + Sync + 'static,
T: Transport + Send + Sync + 'static,
{
pub fn spawn(self) -> Connection {
pub fn spawn(self) -> ConnectionTask {
let id = self.id;
Connection {
ConnectionTask {
id,
task: tokio::spawn(self.run()),
}
}
async fn run(self) {
let ConnectionBuilder {
let ConnectionTaskBuilder {
id,
handler,
state,
@ -203,7 +202,7 @@ where
// Properly establish the connection's transport
let mut transport = match Weak::upgrade(&verifier) {
Some(verifier) => {
match crate::Connection::server(transport, verifier.as_ref(), keychain).await {
match Connection::server(transport, verifier.as_ref(), keychain).await {
Ok(connection) => connection.into_transport(),
Err(x) => {
terminate_connection!(@error "[Conn {id}] Failed to setup connection: {x}");
@ -299,7 +298,7 @@ where
}
},
Err(x) => {
error!("[Conn {id}] Invalid request: {x}");
error!("[Conn {id}] Invalid request payload: {x}");
}
},
Ok(None) => {

@ -1,4 +1,5 @@
use crate::{ConnectionId, Request, ServerReply};
use super::ServerReply;
use crate::common::{ConnectionId, Request};
use std::sync::Arc;
/// Represents contextual information for working with an inbound request

@ -1,5 +1,5 @@
use super::ServerState;
use crate::AsAny;
use crate::common::AsAny;
use log::*;
use std::{
future::Future,

@ -1,4 +1,4 @@
use crate::ServerRef;
use super::ServerRef;
use std::net::IpAddr;
/// Reference to a TCP server instance

@ -1,4 +1,4 @@
use crate::ServerRef;
use super::ServerRef;
use std::path::{Path, PathBuf};
/// Reference to a unix socket server instance

@ -1,4 +1,4 @@
use crate::ServerRef;
use super::ServerRef;
use std::ffi::{OsStr, OsString};
/// Reference to a unix socket server instance

@ -1,4 +1,4 @@
use crate::{Id, Response};
use crate::common::{Id, Response};
use std::{future::Future, io, pin::Pin, sync::Arc};
use tokio::sync::{mpsc, Mutex};

@ -1,5 +1,5 @@
use super::Shutdown;
use crate::utils::Timer;
use crate::common::utils::Timer;
use log::*;
use std::time::Duration;
use tokio::sync::watch;

@ -1,12 +1,12 @@
use super::{Connection, ConnectionId};
use crate::HeapSecretKey;
use super::ConnectionTask;
use crate::common::{ConnectionId, HeapSecretKey};
use std::collections::HashMap;
use tokio::sync::RwLock;
/// Contains all top-level state for the server
pub struct ServerState {
/// Mapping of connection ids to their transports
pub connections: RwLock<HashMap<ConnectionId, Connection>>,
pub connections: RwLock<HashMap<ConnectionId, ConnectionTask>>,
/// Mapping of connection ids to their authenticated keys
pub authenticated: RwLock<HashMap<ConnectionId, HeapSecretKey>>,

Loading…
Cancel
Save