Unfinished port of client ext to builder

pull/146/head
Chip Senkbeil 2 years ago
parent fd684185cb
commit b74d57618e
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -12,12 +12,12 @@ use tokio::{
task::{JoinError, JoinHandle},
};
mod builder;
pub use builder::*;
mod channel;
pub use channel::*;
mod ext;
pub use ext::*;
/// Represents a client that can be used to send requests & receive responses from a server
pub struct Client<T, U> {
/// Used to send requests to a server
@ -150,6 +150,23 @@ where
}
impl<T, U> Client<T, U> {
/// Creates a new [`TcpClientBuilder`]
pub fn tcp() -> TcpClientBuilder<()> {
TcpClientBuilder::new()
}
/// Creates a new [`UnixSocketClientBuilder`]
#[cfg(unix)]
pub fn unix_socket() -> UnixSocketClientBuilder<()> {
UnixSocketClientBuilder::new()
}
/// Creates a new [`WindowsPipeClientBuilder`]
#[cfg(windows)]
pub fn windows_pipe() -> WindowsPipeClientBuilder<()> {
WindowsPipeClientBuilder::new()
}
/// Convert into underlying channel
pub fn into_channel(self) -> Channel<T, U> {
self.channel

@ -0,0 +1,77 @@
use crate::{
auth::{AuthHandler, Authenticator, FramedAuthenticator},
Client, FramedTransport, TcpTransport,
};
use serde::{de::DeserializeOwned, Serialize};
use std::convert;
use tokio::{io, net::ToSocketAddrs, time::Duration};
/// Builder for a client that will connect over TCP
pub struct TcpClientBuilder<T> {
auth_handler: T,
timeout: Option<Duration>,
}
impl<T> TcpClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> TcpClientBuilder<A> {
TcpClientBuilder {
auth_handler,
timeout: self.timeout,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
timeout: timeout.into(),
}
}
}
impl TcpClientBuilder<()> {
pub fn new() -> Self {
Self {
auth_handler: (),
timeout: None,
}
}
}
impl Default for TcpClientBuilder<()> {
fn default() -> Self {
Self::new()
}
}
impl<A: AuthHandler + Send> TcpClientBuilder<A> {
pub async fn connect<T, U>(self, addr: impl ToSocketAddrs) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
let auth_handler = self.auth_handler;
let timeout = self.timeout;
let f = async move {
let transport = TcpTransport::connect(addr).await?;
// Establish our framed transport, perform a handshake to set the codec, and do
// authentication to ensure the connection can be used
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
FramedAuthenticator::new(&mut transport)
.authenticate(auth_handler)
.await?;
Ok(Client::new(transport))
};
match timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => f.await,
}
}
}

@ -0,0 +1,78 @@
use crate::{
auth::{AuthHandler, Authenticator, FramedAuthenticator},
Client, FramedTransport, UnixSocketTransport,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, path::Path};
use tokio::{io, time::Duration};
/// Builder for a client that will connect over a Unix socket
pub struct UnixSocketClientBuilder<T> {
auth_handler: T,
timeout: Option<Duration>,
}
impl<T> UnixSocketClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> UnixSocketClientBuilder<A> {
UnixSocketClientBuilder {
auth_handler,
timeout: self.timeout,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
timeout: timeout.into(),
}
}
}
impl UnixSocketClientBuilder<()> {
pub fn new() -> Self {
Self {
auth_handler: (),
timeout: None,
}
}
}
impl Default for UnixSocketClientBuilder<()> {
fn default() -> Self {
Self::new()
}
}
impl<A: AuthHandler + Send> UnixSocketClientBuilder<A> {
pub async fn connect<T, U>(self, path: impl AsRef<Path> + Send) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
let auth_handler = self.auth_handler;
let timeout = self.timeout;
let f = async move {
let p = path.as_ref();
let transport = UnixSocketTransport::connect(p).await?;
// Establish our framed transport, perform a handshake to set the codec, and do
// authentication to ensure the connection can be used
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
FramedAuthenticator::new(&mut transport)
.authenticate(auth_handler)
.await?;
Ok(Client::new(transport))
};
match timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => f.await,
}
}
}

@ -0,0 +1,101 @@
use crate::{
auth::{AuthHandler, Authenticator, FramedAuthenticator},
Client, FramedTransport, WindowsPipeTransport,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert,
ffi::{OsStr, OsString},
};
use tokio::{io, time::Duration};
/// Builder for a client that will connect over a Windows pipe
pub struct WindowsPipeClientBuilder<T> {
auth_handler: T,
local: bool,
timeout: Option<Duration>,
}
impl<T> WindowsPipeClientBuilder<T> {
pub fn auth_handler<A: AuthHandler>(self, auth_handler: A) -> WindowsPipeClientBuilder<A> {
WindowsPipeClientBuilder {
auth_handler,
local: self.local,
timeout: self.timeout,
}
}
/// If true, will connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}`; otherwise, will connect using the address verbatim.
pub fn local(self, local: bool) -> Self {
Self {
auth_handler: self.auth_handler,
local,
timeout: self.timeout,
}
}
pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
Self {
auth_handler: self.auth_handler,
local: self.local,
timeout: timeout.into(),
}
}
}
impl WindowsPipeClientBuilder<()> {
pub fn new() -> Self {
Self {
auth_handler: (),
local: false,
timeout: None,
}
}
}
impl Default for WindowsPipeClientBuilder<()> {
fn default() -> Self {
Self::new()
}
}
impl<A: AuthHandler + Send> WindowsPipeClientBuilder<A> {
pub async fn connect<T, U>(self, addr: impl AsRef<OsStr> + Send) -> io::Result<Client<T, U>>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
let auth_handler = self.auth_handler;
let timeout = self.timeout;
let f = async move {
let transport = if self.local {
let mut full_addr = OsString::from(r"\\.\pipe\");
full_addr.push(addr.as_ref());
WindowsPipeTransport::connect(full_addr)
} else {
WindowsPipeTransport::connect(addr.as_ref())
}
.await?;
// Establish our framed transport, perform a handshake to set the codec, and do
// authentication to ensure the connection can be used
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
FramedAuthenticator::new(&mut transport)
.authenticate(auth_handler)
.await?;
Ok(Client::new(transport))
};
match timeout {
Some(duration) => tokio::time::timeout(duration, f)
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity),
None => f.await,
}
}
}

@ -1,43 +0,0 @@
use crate::{Client, FramedTransport, TcpTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, net::SocketAddr};
use tokio::{io, time::Duration};
/// Interface that provides ability to connect to a TCP server
#[async_trait]
pub trait TcpClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a remote TCP server using the provided information
async fn connect(addr: SocketAddr) -> io::Result<Client<T, U>>;
/// Connect to a remote TCP server, timing out after duration has passed
async fn connect_timeout<C>(addr: SocketAddr, duration: Duration) -> io::Result<Client<T, U>> {
tokio::time::timeout(duration, Self::connect(addr))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl<T, U> TcpClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Connect to a remote TCP server using the provided information
async fn connect(addr: SocketAddr) -> io::Result<Client<T, U>> {
let transport = TcpTransport::connect(addr).await?;
// Establish our framed transport and perform a handshake to set the codec
// NOTE: Using default capacity
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
Ok(Self::new(transport))
}
}

@ -1,51 +0,0 @@
use crate::{Client, FramedTransport, UnixSocketTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{convert, path::Path};
use tokio::{io, time::Duration};
#[async_trait]
pub trait UnixSocketClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a proxy unix socket
async fn connect<P>(path: P) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send;
/// Connect to a proxy unix socket, timing out after duration has passed
async fn connect_timeout<P>(path: P, duration: Duration) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send,
{
tokio::time::timeout(duration, Self::connect(path))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
}
#[async_trait]
impl<T, U> UnixSocketClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
/// Connect to a proxy unix socket
async fn connect<P>(path: P) -> io::Result<Client<T, U>>
where
P: AsRef<Path> + Send,
{
let p = path.as_ref();
let transport = UnixSocketTransport::connect(p).await?;
// Establish our framed transport and perform a handshake to set the codec
// NOTE: Using default capacity
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
Ok(Client::new(transport))
}
}

@ -1,77 +0,0 @@
use crate::{Client, FramedTransport, WindowsPipeTransport};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert,
ffi::{OsStr, OsString},
};
use tokio::{io, time::Duration};
#[async_trait]
pub trait WindowsPipeClientExt<T, U>
where
T: Serialize + Send + Sync,
U: DeserializeOwned + Send + Sync,
{
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec
async fn connect<A>(addr: A) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send;
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}` using the given codec
async fn connect_local<N>(name: N) -> io::Result<Client<T, U>>
where
N: AsRef<OsStr> + Send,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect(addr).await
}
/// Connect to a server listening on a Windows pipe at the specified address
/// using the given codec, timing out after duration has passed
async fn connect_timeout<A>(addr: A, duration: Duration) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send,
{
tokio::time::timeout(duration, Self::connect(addr))
.await
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
.and_then(convert::identity)
}
/// Connect to a server listening on a Windows pipe at the specified address
/// via `\\.\pipe\{name}`, timing out after duration has passed
async fn connect_local_timeout<N>(name: N, duration: Duration) -> io::Result<Client<T, U>>
where
N: AsRef<OsStr> + Send,
{
let mut addr = OsString::from(r"\\.\pipe\");
addr.push(name.as_ref());
Self::connect_timeout(addr, duration).await
}
}
#[async_trait]
impl<T, U> WindowsPipeClientExt<T, U> for Client<T, U>
where
T: Send + Sync + Serialize + 'static,
U: Send + Sync + DeserializeOwned + 'static,
{
async fn connect<A>(addr: A) -> io::Result<Client<T, U>>
where
A: AsRef<OsStr> + Send,
{
let a = addr.as_ref();
let transport = WindowsPipeTransport::connect(a).await?;
// Establish our framed transport and perform a handshake to set the codec
// NOTE: Using default capacity
let mut transport = FramedTransport::<_>::plain(transport);
transport.client_handshake().await?;
Ok(Client::new(transport))
}
}

@ -41,7 +41,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, Request, ServerCtx, TcpClientExt};
use crate::{Client, Request, ServerCtx};
use std::net::{Ipv6Addr, SocketAddr};
pub struct TestServer;
@ -67,10 +67,10 @@ mod tests {
.await
.expect("Failed to start TCP server");
let mut client: Client<String, String> =
Client::connect(SocketAddr::from((server.ip_addr(), server.port())))
.await
.expect("Client failed to connect");
let mut client: Client<String, String> = Client::tcp()
.connect(SocketAddr::from((server.ip_addr(), server.port())))
.await
.expect("Client failed to connect");
let response = client
.send(Request::new("hello".to_string()))

@ -43,7 +43,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{Client, Request, ServerCtx, UnixSocketClientExt};
use crate::{Client, Request, ServerCtx};
use tempfile::NamedTempFile;
pub struct TestServer;
@ -75,7 +75,8 @@ mod tests {
.await
.expect("Failed to start Unix socket server");
let mut client: Client<String, String> = Client::connect(server.path())
let mut client: Client<String, String> = Client::unix_socket()
.connect(server.path())
.await
.expect("Client failed to connect");

@ -1,169 +0,0 @@
use distant_net::{
AuthClient, AuthErrorKind, AuthQuestion, AuthRequest, AuthServer, AuthVerifyKind, Client,
InmemoryTypedTransport, IntoSplit, MpscListener, ServerExt,
};
use std::collections::HashMap;
use tokio::sync::mpsc;
/// Spawns a server and client connected together, returning the client
fn setup() -> (AuthClient, mpsc::Receiver<AuthRequest>) {
// Make a pair of inmemory transports that we can use to test client and server connected
let (t1, t2) = InmemoryTypedTransport::pair(100);
// Create the client
let (writer, reader) = t1.into_split();
let client = AuthClient::from(Client::new(writer, reader).unwrap());
// Prepare a channel where we can pass back out whatever request we get
let (tx, rx) = mpsc::channel(100);
let tx_2 = tx.clone();
let tx_3 = tx.clone();
let tx_4 = tx.clone();
// Make a server that echos questions back as answers and only verifies the text "yes"
let server = AuthServer {
on_challenge: move |questions, options| {
let questions_2 = questions.clone();
tx.try_send(AuthRequest::Challenge { questions, options })
.unwrap();
questions_2.into_iter().map(|x| x.text).collect()
},
on_verify: move |kind, text| {
let valid = text == "yes";
tx_2.try_send(AuthRequest::Verify { kind, text }).unwrap();
valid
},
on_info: move |text| {
tx_3.try_send(AuthRequest::Info { text }).unwrap();
},
on_error: move |kind, text| {
tx_4.try_send(AuthRequest::Error { kind, text }).unwrap();
},
};
// Spawn the server to listen for our client to connect
tokio::spawn(async move {
let (writer, reader) = t2.into_split();
let (tx, listener) = MpscListener::channel(1);
tx.send((writer, reader)).await.unwrap();
let _server = server.start(listener).unwrap();
});
(client, rx)
}
#[tokio::test]
async fn client_should_be_able_to_challenge_against_server() {
let (mut client, mut rx) = setup();
// Gotta start with the handshake first
client.handshake().await.unwrap();
// Now do the challenge
assert_eq!(
client
.challenge(
vec![AuthQuestion::new("hello".to_string())],
Default::default()
)
.await
.unwrap(),
vec!["hello".to_string()]
);
// Verify that the server received the request
let request = rx.recv().await.unwrap();
match request {
AuthRequest::Challenge { questions, options } => {
assert_eq!(questions.len(), 1);
assert_eq!(questions[0].text, "hello");
assert_eq!(questions[0].options, HashMap::new());
assert_eq!(options, HashMap::new());
}
x => panic!("Unexpected request received by server: {:?}", x),
}
}
#[tokio::test]
async fn client_should_be_able_to_verify_against_server() {
let (mut client, mut rx) = setup();
// Gotta start with the handshake first
client.handshake().await.unwrap();
// "no" will yield false
assert!(!client
.verify(AuthVerifyKind::Host, "no".to_string())
.await
.unwrap());
// Verify that the server received the request
let request = rx.recv().await.unwrap();
match request {
AuthRequest::Verify { kind, text } => {
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "no");
}
x => panic!("Unexpected request received by server: {:?}", x),
}
// "yes" will yield true
assert!(client
.verify(AuthVerifyKind::Host, "yes".to_string())
.await
.unwrap());
// Verify that the server received the request
let request = rx.recv().await.unwrap();
match request {
AuthRequest::Verify { kind, text } => {
assert_eq!(kind, AuthVerifyKind::Host);
assert_eq!(text, "yes");
}
x => panic!("Unexpected request received by server: {:?}", x),
}
}
#[tokio::test]
async fn client_should_be_able_to_send_info_to_server() {
let (mut client, mut rx) = setup();
// Gotta start with the handshake first
client.handshake().await.unwrap();
// Send some information
client.info(String::from("hello, world")).await.unwrap();
// Verify that the server received the request
let request = rx.recv().await.unwrap();
match request {
AuthRequest::Info { text } => assert_eq!(text, "hello, world"),
x => panic!("Unexpected request received by server: {:?}", x),
}
}
#[tokio::test]
async fn client_should_be_able_to_send_error_to_server() {
let (mut client, mut rx) = setup();
// Gotta start with the handshake first
client.handshake().await.unwrap();
// Send some error
client
.error(AuthErrorKind::Unknown, String::from("hello, world"))
.await
.unwrap();
// Verify that the server received the request
let request = rx.recv().await.unwrap();
match request {
AuthRequest::Error { kind, text } => {
assert_eq!(kind, AuthErrorKind::Unknown);
assert_eq!(text, "hello, world");
}
x => panic!("Unexpected request received by server: {:?}", x),
}
}
Loading…
Cancel
Save