Add a proxy auth handler that uses a mutable reference to an authenticator

pull/146/head
Chip Senkbeil 2 years ago
parent 920c3c5578
commit 8fb2c5ddb8
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

@ -1,4 +1,5 @@
use super::msg::*;
use crate::common::authentication::Authenticator;
use crate::common::HeapSecretKey;
use async_trait::async_trait;
use std::collections::HashMap;
@ -9,7 +10,7 @@ pub use methods::*;
/// Interface for a handler of authentication requests for all methods.
#[async_trait]
pub trait AuthHandler: AuthMethodHandler {
pub trait AuthHandler: AuthMethodHandler + Send {
/// Callback when authentication is beginning, providing available authentication methods and
/// returning selected authentication methods to pursue.
async fn on_initialization(
@ -69,7 +70,7 @@ impl AuthMethodHandler for DummyAuthHandler {
/// [`on_error`]: AuthMethodHandler::on_error
pub struct AuthHandlerMap {
active: String,
map: HashMap<&'static str, Box<dyn AuthMethodHandler + Send>>,
map: HashMap<&'static str, Box<dyn AuthMethodHandler>>,
}
impl AuthHandlerMap {
@ -93,11 +94,11 @@ impl AuthHandlerMap {
/// Inserts the specified `handler` into the map, associating it with `id` for determining the
/// method that would trigger this handler.
pub fn insert_method_handler<T: AuthMethodHandler + Send + 'static>(
pub fn insert_method_handler<T: AuthMethodHandler + 'static>(
&mut self,
id: &'static str,
handler: T,
) -> Option<Box<dyn AuthMethodHandler + Send>> {
) -> Option<Box<dyn AuthMethodHandler>> {
self.map.insert(id, Box::new(handler))
}
@ -105,7 +106,7 @@ impl AuthHandlerMap {
pub fn remove_method_handler(
&mut self,
id: &'static str,
) -> Option<Box<dyn AuthMethodHandler + Send>> {
) -> Option<Box<dyn AuthMethodHandler>> {
self.map.remove(id)
}
@ -113,7 +114,7 @@ impl AuthHandlerMap {
/// returning an error if no handler for the active id is found.
pub fn get_mut_active_method_handler_or_error(
&mut self,
) -> io::Result<&mut (dyn AuthMethodHandler + Send + 'static)> {
) -> io::Result<&mut (dyn AuthMethodHandler + 'static)> {
self.get_mut_active_method_handler()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "No active handler for id"))
}
@ -121,7 +122,7 @@ impl AuthHandlerMap {
/// Retrieves a mutable reference to the active [`AuthMethodHandler`] with the specified `id`.
pub fn get_mut_active_method_handler(
&mut self,
) -> Option<&mut (dyn AuthMethodHandler + Send + 'static)> {
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
// TODO: Optimize this
self.get_mut_method_handler(&self.active.clone())
}
@ -130,7 +131,7 @@ impl AuthHandlerMap {
pub fn get_mut_method_handler(
&mut self,
id: &str,
) -> Option<&mut (dyn AuthMethodHandler + Send + 'static)> {
) -> Option<&mut (dyn AuthMethodHandler + 'static)> {
self.map.get_mut(id).map(|h| h.as_mut())
}
}
@ -199,3 +200,52 @@ impl AuthMethodHandler for AuthHandlerMap {
handler.on_error(error).await
}
}
/// Implementation of [`AuthHandler`] that redirects all requests to an [`Authenticator`].
pub struct ProxyAuthHandler<'a>(&'a mut dyn Authenticator);
impl<'a> ProxyAuthHandler<'a> {
pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
Self(authenticator)
}
}
#[async_trait]
impl<'a> AuthHandler for ProxyAuthHandler<'a> {
async fn on_initialization(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
Authenticator::initialize(self.0, initialization).await
}
async fn on_start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
Authenticator::start_method(self.0, start_method).await
}
async fn on_finished(&mut self) -> io::Result<()> {
Authenticator::finished(self.0).await
}
}
#[async_trait]
impl<'a> AuthMethodHandler for ProxyAuthHandler<'a> {
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
Authenticator::challenge(self.0, challenge).await
}
async fn on_verification(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
Authenticator::verify(self.0, verification).await
}
async fn on_info(&mut self, info: Info) -> io::Result<()> {
Authenticator::info(self.0, info).await
}
async fn on_error(&mut self, error: Error) -> io::Result<()> {
Authenticator::error(self.0, error).await
}
}

@ -6,7 +6,7 @@ use std::io;
/// Interface for a handler of authentication requests for a specific authentication method.
#[async_trait]
pub trait AuthMethodHandler {
pub trait AuthMethodHandler: Send {
/// Callback when a challenge is received, returning answers to the given questions.
async fn on_challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse>;

@ -12,16 +12,13 @@ use std::io;
/// [`AuthMethodHandler`].
pub struct StaticKeyAuthMethodHandler {
key: HeapSecretKey,
handler: Box<dyn AuthMethodHandler + Send>,
handler: Box<dyn AuthMethodHandler>,
}
impl StaticKeyAuthMethodHandler {
/// Creates a new [`StaticKeyAuthMethodHandler`] that responds to challenges using a static
/// `key`. All other requests are passed to the `handler`.
pub fn new<T: AuthMethodHandler + Send + 'static>(
key: impl Into<HeapSecretKey>,
handler: T,
) -> Self {
pub fn new<T: AuthMethodHandler + 'static>(key: impl Into<HeapSecretKey>, handler: T) -> Self {
Self {
key: key.into(),
handler: Box::new(handler),

@ -1,19 +1,17 @@
use crate::config::ClientLaunchConfig;
use async_trait::async_trait;
use distant_core::{
net::{
AuthClient, AuthQuestion, FramedTransport, IntoSplit, SecretKey32, TcpTransport,
XChaCha20Poly1305Codec,
},
BoxedDistantReader, BoxedDistantWriter, BoxedDistantWriterReader, ConnectHandler, Destination,
LaunchHandler, Map,
};
use distant_core::net::client::{Client, ReconnectStrategy, UntypedClient};
use distant_core::net::common::authentication::msg::*;
use distant_core::net::common::authentication::{Authenticator, ProxyAuthHandler};
use distant_core::net::common::{Destination, Map};
use distant_core::net::manager::{ConnectHandler, LaunchHandler};
use log::*;
use std::{
io,
net::{IpAddr, SocketAddr},
path::PathBuf,
process::Stdio,
time::Duration,
};
use tokio::{
io::{AsyncBufReadExt, BufReader},
@ -50,7 +48,7 @@ impl LaunchHandler for ManagerLaunchHandler {
&self,
destination: &Destination,
options: &Map,
_auth_client: &mut AuthClient,
_auth_client: &mut dyn Authenticator,
) -> io::Result<Destination> {
debug!("Handling launch of {destination} with options '{options}'");
let config = ClientLaunchConfig::from(options.clone());
@ -163,7 +161,7 @@ impl LaunchHandler for SshLaunchHandler {
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
auth_client: &mut dyn Authenticator,
) -> io::Result<Destination> {
debug!("Handling launch of {destination} with options '{options}'");
let config = ClientLaunchConfig::from(options.clone());
@ -196,14 +194,31 @@ impl LaunchHandler for SshLaunchHandler {
pub struct DistantConnectHandler;
impl DistantConnectHandler {
pub async fn try_connect(ips: Vec<IpAddr>, port: u16) -> io::Result<TcpTransport> {
pub async fn try_connect(
ips: Vec<IpAddr>,
port: u16,
authenticator: &mut dyn Authenticator,
) -> io::Result<UntypedClient> {
// Try each IP address with the same port to see if one works
let mut err = None;
for ip in ips {
let addr = SocketAddr::new(ip, port);
debug!("Attempting to connect to distant server @ {}", addr);
match TcpTransport::connect(addr).await {
Ok(transport) => return Ok(transport),
match Client::tcp(addr)
.auth_handler(ProxyAuthHandler::new(authenticator))
.reconnect_strategy(ReconnectStrategy::ExponentialBackoff {
base: Duration::from_secs(1),
factor: 2.0,
max_duration: None,
max_retries: None,
timeout: None,
})
.timeout(Duration::from_secs(180))
.connect_untyped()
.await
{
Ok(client) => return Ok(client),
Err(x) => err = Some(x),
}
}
@ -219,8 +234,8 @@ impl ConnectHandler for DistantConnectHandler {
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<BoxedDistantWriterReader> {
authenticator: &mut dyn Authenticator,
) -> io::Result<UntypedClient> {
debug!("Handling connect of {destination} with options '{options}'");
let host = destination.host.to_string();
let port = destination.port.ok_or_else(|| missing("port"))?;
@ -246,37 +261,7 @@ impl ConnectHandler for DistantConnectHandler {
));
}
// Use provided password or options key if available, otherwise ask for it, and produce a
// codec using the key
let codec = {
let key = destination
.password
.as_deref()
.or_else(|| options.get("key").map(|s| s.as_str()));
let key = match key {
Some(key) => key.parse::<SecretKey32>().map_err(|_| invalid("key"))?,
None => {
let answers = auth_client
.challenge(vec![AuthQuestion::new("key")], Default::default())
.await?;
answers
.first()
.ok_or_else(|| missing("key"))?
.parse::<SecretKey32>()
.map_err(|_| invalid("key"))?
}
};
XChaCha20Poly1305Codec::from(key)
};
// Establish a TCP connection, wrap it, and split it out into a writer and reader
let transport = Self::try_connect(candidate_ips, port).await?;
let transport = FramedTransport::new(transport, codec);
let (writer, reader) = transport.into_split();
let writer: BoxedDistantWriter = Box::new(writer);
let reader: BoxedDistantReader = Box::new(reader);
Ok((writer, reader))
Self::try_connect(candidate_ips, port, authenticator).await
}
}
@ -291,23 +276,23 @@ impl ConnectHandler for SshConnectHandler {
&self,
destination: &Destination,
options: &Map,
auth_client: &mut AuthClient,
) -> io::Result<BoxedDistantWriterReader> {
auth_client: &mut dyn Authenticator,
) -> io::Result<UntypedClient> {
debug!("Handling connect of {destination} with options '{options}'");
let mut ssh = load_ssh(destination, options)?;
let handler = AuthClientSshAuthHandler::new(auth_client);
let _ = ssh.authenticate(handler).await?;
ssh.into_distant_writer_reader().await
Ok(ssh.into_distant_client().await?.into_untyped_client())
}
}
#[cfg(any(feature = "libssh", feature = "ssh2"))]
struct AuthClientSshAuthHandler<'a>(Mutex<&'a mut AuthClient>);
struct AuthClientSshAuthHandler<'a>(Mutex<&'a mut dyn Authenticator>);
#[cfg(any(feature = "libssh", feature = "ssh2"))]
impl<'a> AuthClientSshAuthHandler<'a> {
pub fn new(auth_client: &'a mut AuthClient) -> Self {
Self(Mutex::new(auth_client))
pub fn new(authenticator: &'a mut dyn Authenticator) -> Self {
Self(Mutex::new(authenticator))
}
}
@ -322,7 +307,8 @@ impl<'a> distant_ssh2::SshAuthHandler for AuthClientSshAuthHandler<'a> {
for prompt in event.prompts {
let mut options = HashMap::new();
options.insert("echo".to_string(), prompt.echo.to_string());
questions.push(AuthQuestion {
questions.push(Question {
label: "ssh-prompt".to_string(),
text: prompt.prompt,
options,
});
@ -331,31 +317,51 @@ impl<'a> distant_ssh2::SshAuthHandler for AuthClientSshAuthHandler<'a> {
options.insert("instructions".to_string(), event.instructions);
options.insert("username".to_string(), event.username);
self.0.lock().await.challenge(questions, options).await
Ok(self
.0
.lock()
.await
.challenge(Challenge { questions, options })
.await?
.answers)
}
async fn on_verify_host(&self, host: &str) -> io::Result<bool> {
use distant_core::net::AuthVerifyKind;
self.0
Ok(self
.0
.lock()
.await
.verify(AuthVerifyKind::Host, host.to_string())
.await
.verify(Verification {
kind: VerificationKind::Host,
text: host.to_string(),
})
.await?
.valid)
}
async fn on_banner(&self, text: &str) {
if let Err(x) = self.0.lock().await.info(text.to_string()).await {
if let Err(x) = self
.0
.lock()
.await
.info(Info {
text: text.to_string(),
})
.await
{
error!("ssh on_banner failed: {}", x);
}
}
async fn on_error(&self, text: &str) {
use distant_core::net::AuthErrorKind;
if let Err(x) = self
.0
.lock()
.await
.error(AuthErrorKind::Unknown, text.to_string())
.error(Error {
kind: ErrorKind::Fatal,
text: text.to_string(),
})
.await
{
error!("ssh on_error failed: {}", x);

@ -1,6 +1,6 @@
use crate::config::BindAddress;
use clap::Args;
use distant_core::Map;
use distant_core::net::common::Map;
use serde::{Deserialize, Serialize};
#[derive(Args, Debug, Default, Serialize, Deserialize)]

@ -1,6 +1,6 @@
use super::{AccessControl, CommonConfig, NetworkConfig};
use clap::Args;
use distant_core::Destination;
use distant_core::net::common::Destination;
use serde::{Deserialize, Serialize};
use service_manager::ServiceManagerKind;

@ -1,9 +1,7 @@
use anyhow::Context;
use clap::Args;
use distant_core::{
net::{PortRange, Shutdown},
Host, HostParseError, Map,
};
use distant_core::net::common::{Host, HostParseError, Map, PortRange};
use distant_core::net::server::Shutdown;
use serde::{Deserialize, Serialize};
use std::{
env, fmt,

Loading…
Cancel
Save