Update server to leverage a verifier for authentication methods

pull/146/head
Chip Senkbeil 2 years ago
parent 2173524e3b
commit 733a0e2a7d
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,16 +1,80 @@
use super::{msg::*, Authenticator};
use crate::HeapSecretKey;
use async_trait::async_trait;
use log::*;
use std::collections::HashMap;
use std::io;
/// Supports authenticating using a variety of methods
pub struct Verifier {
methods: HashMap<&'static str, Box<dyn AuthenticationMethod>>,
}
impl Verifier {
pub fn new<I>(methods: I) -> Self
where
I: IntoIterator<Item = Box<dyn AuthenticationMethod>>,
{
let mut m = HashMap::new();
for method in methods {
m.insert(method.id(), method);
}
Self { methods: m }
}
/// Creates a verifier with no methods.
pub fn empty() -> Self {
Self {
methods: HashMap::new(),
}
}
/// Returns an iterator over the ids of the methods supported by the verifier
pub fn methods(&self) -> impl Iterator<Item = &'static str> + '_ {
self.methods.keys().copied()
}
/// Attempts to verify by submitting challenges using the `authenticator` provided. Returns the
/// id of the authentication method that succeeded.
pub async fn verify(&self, authenticator: &mut dyn Authenticator) -> io::Result<&'static str> {
// Initiate the process to get methods to use
let response = authenticator
.initialize(Initialization {
methods: self.methods.keys().map(ToString::to_string).collect(),
})
.await?;
for method in response.methods {
match self.methods.get(method.as_str()) {
Some(method) => {
if method.authenticate(authenticator).await.is_ok() {
authenticator.finished().await?;
return Ok(method.id());
}
}
None => {
trace!("Skipping authentication {method} as it is not available or supported");
}
}
}
Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"No authentication method succeeded",
))
}
}
/// Represents an interface to authenticate using some method
#[async_trait]
pub trait AuthenticationMethod: Sized {
pub trait AuthenticationMethod: Send + Sync {
/// Returns a unique id to distinguish the method from other methods
fn id() -> &'static str;
fn id(&self) -> &'static str;
// TODO: add a unique id method and update below method to take dyn ref so it can be boxed.
// that way, we can pass to server a collection of boxed methods
/// Performs authentication using the `authenticator` to submit challenges and other
/// information based on the authentication method
async fn authenticate(&self, authenticator: &mut dyn Authenticator) -> io::Result<()>;
}
@ -34,7 +98,7 @@ impl Default for NoneAuthenticationMethod {
#[async_trait]
impl AuthenticationMethod for NoneAuthenticationMethod {
fn id() -> &'static str {
fn id(&self) -> &'static str {
"none"
}
@ -58,7 +122,7 @@ impl StaticKeyAuthenticationMethod {
#[async_trait]
impl AuthenticationMethod for StaticKeyAuthenticationMethod {
fn id() -> &'static str {
fn id(&self) -> &'static str {
"static_key"
}
@ -114,7 +178,7 @@ impl ReauthenticationMethod {
#[async_trait]
impl AuthenticationMethod for ReauthenticationMethod {
fn id() -> &'static str {
fn id(&self) -> &'static str {
"reauthentication"
}

@ -1,4 +1,7 @@
use crate::{auth::Authenticator, Listener, Transport};
use crate::{
auth::{Authenticator, Verifier},
Listener, Transport,
};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -37,6 +40,9 @@ pub struct Server<T> {
/// Handler used to process various server events
handler: T,
/// Performs authentication using various methods
verifier: Verifier,
}
/// Interface for a handler that receives connections and requests
@ -71,11 +77,13 @@ pub trait ServerHandler: Send {
}
impl Server<()> {
/// Creates a new [`Server`], starting with a default configuration and no [`ServerHandler`].
/// Creates a new [`Server`], starting with a default configuration, no authentication methods,
/// and no [`ServerHandler`].
pub fn new() -> Self {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
}
}
@ -109,6 +117,7 @@ impl<T> Server<T> {
Self {
config,
handler: self.handler,
verifier: self.verifier,
}
}
@ -117,6 +126,16 @@ impl<T> Server<T> {
Server {
config: self.config,
handler,
verifier: self.verifier,
}
}
/// Consumes the current server, replacing its verifier with `verifier` and returning it.
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
}
}
}
@ -147,12 +166,17 @@ where
L: Listener + 'static,
L::Output: Transport + Send + Sync + 'static,
{
let Server { config, handler } = self;
let Server {
config,
handler,
verifier,
} = self;
let handler = Arc::new(handler);
let timer = ShutdownTimer::start(config.shutdown);
let mut notification = timer.clone_notification();
let timer = Arc::new(RwLock::new(timer));
let verifier = Arc::new(verifier);
loop {
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
@ -185,6 +209,7 @@ where
.state(Arc::downgrade(&state))
.transport(transport)
.shutdown_timer(Arc::downgrade(&timer))
.verifier(Arc::downgrade(&verifier))
.spawn();
state
@ -200,7 +225,8 @@ where
mod tests {
use super::*;
use crate::{
auth::Authenticator, InmemoryTransport, MpscListener, Request, Response, ServerConfig,
auth::{AuthenticationMethod, Authenticator, NoneAuthenticationMethod},
InmemoryTransport, MpscListener, Request, Response, ServerConfig,
};
use async_trait::async_trait;
use std::time::Duration;
@ -230,9 +256,13 @@ mod tests {
#[inline]
fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(NoneAuthenticationMethod::new())];
Server {
config,
handler: TestServerHandler,
verifier: Verifier::new(methods),
}
}
@ -266,6 +296,7 @@ mod tests {
.await
.expect("Failed to send request");
// Wait for a response
let mut buf = [0u8; 1024];
let n = transport.try_read(&mut buf).unwrap();
let response: Response<String> = Response::from_slice(&buf[..n]).unwrap();

@ -1,10 +1,13 @@
use crate::{PortRange, Server, ServerConfig, ServerHandler, TcpListener, TcpServerRef};
use crate::{
auth::Verifier, PortRange, Server, ServerConfig, ServerHandler, TcpListener, TcpServerRef,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{io, net::IpAddr};
pub struct TcpServerBuilder<T> {
config: ServerConfig,
handler: T,
verifier: Verifier,
}
impl Default for TcpServerBuilder<()> {
@ -12,6 +15,7 @@ impl Default for TcpServerBuilder<()> {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
}
}
}
@ -21,6 +25,7 @@ impl<T> TcpServerBuilder<T> {
Self {
config,
handler: self.handler,
verifier: self.verifier,
}
}
@ -28,6 +33,15 @@ impl<T> TcpServerBuilder<T> {
TcpServerBuilder {
config: self.config,
handler,
verifier: self.verifier,
}
}
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
}
}
}
@ -48,6 +62,7 @@ where
let server = Server {
config: self.config,
handler: self.handler,
verifier: self.verifier,
};
let inner = server.start(listener)?;
Ok(TcpServerRef { addr, port, inner })

@ -1,10 +1,13 @@
use crate::{Server, ServerConfig, ServerHandler, UnixSocketListener, UnixSocketServerRef};
use crate::{
auth::Verifier, Server, ServerConfig, ServerHandler, UnixSocketListener, UnixSocketServerRef,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{io, path::Path};
pub struct UnixSocketServerBuilder<T> {
config: ServerConfig,
handler: T,
verifier: Verifier,
}
impl Default for UnixSocketServerBuilder<()> {
@ -12,6 +15,7 @@ impl Default for UnixSocketServerBuilder<()> {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
}
}
}
@ -21,6 +25,7 @@ impl<T> UnixSocketServerBuilder<T> {
Self {
config,
handler: self.handler,
verifier: self.verifier,
}
}
@ -28,6 +33,15 @@ impl<T> UnixSocketServerBuilder<T> {
UnixSocketServerBuilder {
config: self.config,
handler,
verifier: self.verifier,
}
}
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
}
}
}
@ -50,6 +64,7 @@ where
let server = Server {
config: self.config,
handler: self.handler,
verifier: self.verifier,
};
let inner = server.start(listener)?;
Ok(UnixSocketServerRef { path, inner })

@ -1,4 +1,6 @@
use crate::{Server, ServerConfig, ServerHandler, WindowsPipeListener, WindowsPipeServerRef};
use crate::{
auth::Verifier, Server, ServerConfig, ServerHandler, WindowsPipeListener, WindowsPipeServerRef,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{
ffi::{OsStr, OsString},
@ -8,6 +10,7 @@ use std::{
pub struct WindowsPipeServerBuilder<T> {
config: ServerConfig,
handler: T,
verifier: Verifier,
}
impl Default for WindowsPipeServerBuilder<()> {
@ -15,15 +18,25 @@ impl Default for WindowsPipeServerBuilder<()> {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
}
}
}
impl<T> WindowsPipeServerBuilder<T> {
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
}
}
pub fn config(self, config: ServerConfig) -> Self {
Self {
config,
handler: self.handler,
verifier: self.verifier,
}
}
@ -31,6 +44,7 @@ impl<T> WindowsPipeServerBuilder<T> {
WindowsPipeServerBuilder {
config: self.config,
handler,
verifier: self.verifier,
}
}
}
@ -54,6 +68,7 @@ where
let server = Server {
config: self.config,
handler: self.handler,
verifier: self.verifier,
};
let inner = server.start(listener)?;
Ok(WindowsPipeServerRef { addr, inner })
@ -85,7 +100,6 @@ mod tests {
Client, ConnectionCtx, Request, ServerCtx,
};
use async_trait::async_trait;
use std::collections::HashMap;
pub struct TestServerHandler;

@ -1,7 +1,7 @@
use super::{ServerState, ShutdownTimer};
use crate::{
ConnectionCtx, FramedTransport, Interest, Response, ServerCtx, ServerHandler, ServerReply,
Transport, UntypedRequest,
auth::Verifier, ConnectionCtx, FramedTransport, Interest, Response, ServerCtx, ServerHandler,
ServerReply, Transport, UntypedRequest,
};
use log::*;
use serde::{de::DeserializeOwned, Serialize};
@ -40,6 +40,7 @@ impl Connection {
state: Weak::new(),
transport: (),
shutdown_timer: Weak::new(),
verifier: Weak::new(),
}
}
@ -60,6 +61,7 @@ pub struct ConnectionBuilder<H, T> {
state: Weak<ServerState>,
transport: T,
shutdown_timer: Weak<RwLock<ShutdownTimer>>,
verifier: Weak<Verifier>,
}
impl<H, T> ConnectionBuilder<H, T> {
@ -70,6 +72,7 @@ impl<H, T> ConnectionBuilder<H, T> {
state: self.state,
transport: self.transport,
shutdown_timer: self.shutdown_timer,
verifier: self.verifier,
}
}
@ -80,6 +83,7 @@ impl<H, T> ConnectionBuilder<H, T> {
state,
transport: self.transport,
shutdown_timer: self.shutdown_timer,
verifier: self.verifier,
}
}
@ -90,6 +94,7 @@ impl<H, T> ConnectionBuilder<H, T> {
state: self.state,
transport,
shutdown_timer: self.shutdown_timer,
verifier: self.verifier,
}
}
@ -103,6 +108,18 @@ impl<H, T> ConnectionBuilder<H, T> {
state: self.state,
transport: self.transport,
shutdown_timer,
verifier: self.verifier,
}
}
pub fn verifier(self, verifier: Weak<Verifier>) -> ConnectionBuilder<H, T> {
ConnectionBuilder {
id: self.id,
handler: self.handler,
state: self.state,
transport: self.transport,
shutdown_timer: self.shutdown_timer,
verifier,
}
}
}
@ -131,6 +148,7 @@ where
state,
transport,
shutdown_timer,
verifier,
} = self;
// Attempt to upgrade our handler for use with the connection going forward
@ -152,6 +170,20 @@ where
return;
}
// Perform authentication to ensure the connection is valid
match Weak::upgrade(&verifier) {
Some(verifier) => {
if let Err(x) = verifier.verify(&mut transport).await {
error!("[Conn {id}] Verification failed: {x}");
return;
}
}
None => {
error!("[Conn {id}] Verifier has been dropped");
return;
}
};
// Create local data for the connection and then process it as well as perform
// authentication and any other tasks on first connecting
let mut local_data = H::LocalData::default();

Loading…
Cancel
Save