melib: add smtp-trace feature

If it's enabled, every read/write in an SMTP transaction will be logged
on TRACE level.
This commit is contained in:
Manos Pitsidianakis 2023-07-22 16:25:54 +03:00
parent 073d43b9b8
commit 4874e30f3c
No known key found for this signature in database
GPG Key ID: 7729C7707F7E09D0
5 changed files with 175 additions and 52 deletions

View File

@ -83,6 +83,7 @@ maildir_backend = ["notify"]
mbox_backend = ["notify"]
notmuch_backend = []
smtp = ["tls", "base64"]
smtp-trace = ["smtp"]
sqlite3 = ["rusqlite", ]
tls = ["native-tls"]
unicode_algorithms = ["unicode-segmentation"]

View File

@ -179,7 +179,7 @@ impl ImapStream {
let addr = lookup_ipv4(path, server_conf.server_port)?;
let mut socket = AsyncWrapper::new(Connection::Tcp(
let mut socket = AsyncWrapper::new(Connection::new_tcp(
if let Some(timeout) = server_conf.timeout {
TcpStream::connect_timeout(&addr, timeout)?
} else {
@ -271,14 +271,14 @@ impl ImapStream {
}
}
}
AsyncWrapper::new(Connection::Tls(conn_result.chain_err_summary(|| {
format!("Could not initiate TLS negotiation to {}.", path)
})?))
AsyncWrapper::new(Connection::new_tls(conn_result.chain_err_summary(
|| format!("Could not initiate TLS negotiation to {}.", path),
)?))
.chain_err_summary(|| format!("Could not initiate TLS negotiation to {}.", path))?
}
} else {
let addr = lookup_ipv4(path, server_conf.server_port)?;
AsyncWrapper::new(Connection::Tcp(
AsyncWrapper::new(Connection::new_tcp(
if let Some(timeout) = server_conf.timeout {
TcpStream::connect_timeout(&addr, timeout)?
} else {

View File

@ -79,7 +79,7 @@ impl NntpStream {
let stream = {
let addr = lookup_ipv4(path, server_conf.server_port)?;
AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout(
AsyncWrapper::new(Connection::new_tcp(TcpStream::connect_timeout(
&addr,
std::time::Duration::new(16, 0),
)?))?
@ -170,8 +170,8 @@ impl NntpStream {
}
}
}
ret.stream =
AsyncWrapper::new(Connection::Tls(conn_result?)).chain_err_summary(|| {
ret.stream = AsyncWrapper::new(Connection::new_tls(conn_result?))
.chain_err_summary(|| {
format!("Could not initiate TLS negotiation to {}.", path)
})?;
}

View File

@ -262,10 +262,16 @@ impl SmtpConnection {
let connector = connector.build()?;
let addr = lookup_ipv4(path, server_conf.port)?;
let mut socket = AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout(
&addr,
std::time::Duration::new(4, 0),
)?))?;
let mut socket = {
let conn = Connection::new_tcp(TcpStream::connect_timeout(
&addr,
std::time::Duration::new(4, 0),
)?);
#[cfg(feature = "smtp-trace")]
let conn = conn.trace(true);
AsyncWrapper::new(conn)?
};
let pre_ehlo_extensions_reply = read_lines(
&mut socket,
&mut res,
@ -315,6 +321,8 @@ impl SmtpConnection {
let mut ret = {
let socket = socket.into_inner()?;
#[cfg(feature = "smtp-trace")]
let socket = socket.trace(false);
let _path = path.clone();
socket.set_nonblocking(false)?;
@ -340,17 +348,27 @@ impl SmtpConnection {
}
}
*/
AsyncWrapper::new(Connection::Tls(conn))?
AsyncWrapper::new({
let conn = Connection::new_tls(conn);
#[cfg(feature = "smtp-trace")]
let conn = conn.trace(true);
conn
})?
};
ret.write_all(b"EHLO meli.delivery\r\n").await?;
ret
}
SmtpSecurity::None => {
let addr = lookup_ipv4(path, server_conf.port)?;
let mut ret = AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout(
&addr,
std::time::Duration::new(4, 0),
)?))?;
let mut ret = AsyncWrapper::new({
let conn = Connection::new_tcp(TcpStream::connect_timeout(
&addr,
std::time::Duration::new(4, 0),
)?);
#[cfg(feature = "smtp-trace")]
let conn = conn.trace(true);
conn
})?;
res.clear();
let reply = read_lines(
&mut ret,

View File

@ -40,13 +40,23 @@ use libc::{self, c_int, c_void};
#[derive(Debug)]
pub enum Connection {
Tcp(std::net::TcpStream),
Fd(std::os::unix::io::RawFd),
Tcp {
inner: std::net::TcpStream,
trace: bool,
},
Fd {
inner: std::os::unix::io::RawFd,
trace: bool,
},
#[cfg(feature = "tls")]
Tls(native_tls::TlsStream<Self>),
Tls {
inner: native_tls::TlsStream<Self>,
trace: bool,
},
#[cfg(feature = "deflate_compression")]
Deflate {
inner: DeflateEncoder<DeflateDecoder<Box<Self>>>,
trace: bool,
},
}
@ -65,26 +75,77 @@ macro_rules! syscall {
}
impl Connection {
pub const IO_BUF_SIZE: usize = 64 * 1024;
#[cfg(feature = "deflate_compression")]
pub fn deflate(self) -> Self {
pub const IO_BUF_SIZE: usize = 64 * 1024;
#[cfg(feature = "deflate_compression")]
pub fn deflate(mut self) -> Self {
let trace = self.is_trace_enabled();
self.set_trace(false);
Self::Deflate {
inner: DeflateEncoder::new(
DeflateDecoder::new_with_buf(Box::new(self), vec![0; Self::IO_BUF_SIZE]),
Compression::default(),
),
trace,
}
}
#[cfg(feature = "tls")]
pub fn new_tls(mut inner: native_tls::TlsStream<Self>) -> Self {
let trace = inner.get_ref().is_trace_enabled();
if trace {
inner.get_mut().set_trace(false);
}
Self::Tls { inner, trace }
}
pub fn new_tcp(inner: std::net::TcpStream) -> Self {
Self::Tcp {
inner,
trace: false,
}
}
pub fn trace(mut self, val: bool) -> Self {
match self {
Tcp { ref mut trace, .. } => *trace = val,
#[cfg(feature = "tls")]
Tls { ref mut trace, .. } => *trace = val,
Fd { ref mut trace, .. } => {
*trace = val;
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut trace, .. } => *trace = val,
}
self
}
pub fn set_trace(&mut self, val: bool) {
match self {
Tcp { ref mut trace, .. } => *trace = val,
#[cfg(feature = "tls")]
Tls { ref mut trace, .. } => *trace = val,
Fd { ref mut trace, .. } => {
*trace = val;
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut trace, .. } => *trace = val,
}
}
pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
if self.is_trace_enabled() {
log::trace!("{:?} set_nonblocking({:?})", self, nonblocking);
}
match self {
Tcp(ref t) => t.set_nonblocking(nonblocking),
Tcp { ref inner, .. } => inner.set_nonblocking(nonblocking),
#[cfg(feature = "tls")]
Tls(ref t) => t.get_ref().set_nonblocking(nonblocking),
Fd(fd) => {
Tls { ref inner, .. } => inner.get_ref().set_nonblocking(nonblocking),
Fd { inner, .. } => {
// [ref:VERIFY]
nix::fcntl::fcntl(
*fd,
*inner,
nix::fcntl::FcntlArg::F_SETFL(if nonblocking {
nix::fcntl::OFlag::O_NONBLOCK
} else {
@ -100,29 +161,38 @@ impl Connection {
}
pub fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
if self.is_trace_enabled() {
log::trace!("{:?} set_read_timeout({:?})", self, dur);
}
match self {
Tcp(ref t) => t.set_read_timeout(dur),
Tcp { ref inner, .. } => inner.set_read_timeout(dur),
#[cfg(feature = "tls")]
Tls(ref t) => t.get_ref().set_read_timeout(dur),
Fd(_) => Ok(()),
Tls { ref inner, .. } => inner.get_ref().set_read_timeout(dur),
Fd { .. } => Ok(()),
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().set_read_timeout(dur),
}
}
pub fn set_write_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
if self.is_trace_enabled() {
log::trace!("{:?} set_write_timeout({:?})", self, dur);
}
match self {
Tcp(ref t) => t.set_write_timeout(dur),
Tcp { ref inner, .. } => inner.set_write_timeout(dur),
#[cfg(feature = "tls")]
Tls(ref t) => t.get_ref().set_write_timeout(dur),
Fd(_) => Ok(()),
Tls { ref inner, .. } => inner.get_ref().set_write_timeout(dur),
Fd { .. } => Ok(()),
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur),
}
}
pub fn keepalive(&self) -> std::io::Result<Option<Duration>> {
if let Fd(_) = self {
if self.is_trace_enabled() {
log::trace!("{:?} keepalive()", self);
}
if matches!(self, Fd { .. }) {
return Ok(None);
}
unsafe {
@ -136,7 +206,10 @@ impl Connection {
}
pub fn set_keepalive(&self, keepalive: Option<Duration>) -> std::io::Result<()> {
if let Fd(_) = self {
if self.is_trace_enabled() {
log::trace!("{:?} set_keepalive({:?})", self, keepalive);
}
if matches!(self, Fd { .. }) {
return Ok(());
}
unsafe {
@ -181,44 +254,75 @@ impl Connection {
assert_eq!(len as usize, std::mem::size_of::<T>());
Ok(slot)
}
fn is_trace_enabled(&self) -> bool {
match self {
Fd { trace, .. } | Tcp { trace, .. } => *trace,
#[cfg(feature = "tls")]
Tls { trace, .. } => *trace,
#[cfg(feature = "deflate_compression")]
Deflate { trace, .. } => *trace,
}
}
}
impl Drop for Connection {
fn drop(&mut self) {
if let Fd(fd) = self {
let _ = nix::unistd::close(*fd);
if let Fd { ref inner, .. } = self {
let _ = nix::unistd::close(*inner);
}
}
}
impl std::io::Read for Connection {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Tcp(ref mut t) => t.read(buf),
let res = match self {
Tcp { ref mut inner, .. } => inner.read(buf),
#[cfg(feature = "tls")]
Tls(ref mut t) => t.read(buf),
Fd(f) => {
Tls { ref mut inner, .. } => inner.read(buf),
Fd { ref inner, .. } => {
use std::os::unix::io::{FromRawFd, IntoRawFd};
let mut f = unsafe { std::fs::File::from_raw_fd(*f) };
let mut f = unsafe { std::fs::File::from_raw_fd(*inner) };
let ret = f.read(buf);
let _ = f.into_raw_fd();
ret
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut inner, .. } => inner.read(buf),
};
if self.is_trace_enabled() {
if let Ok(len) = &res {
log::trace!(
"{:?} read {:?} bytes:{:?}",
self,
len,
String::from_utf8_lossy(&buf[..*len])
);
} else {
log::trace!("{:?} could not read {:?}", self, &res);
}
}
res
}
}
impl std::io::Write for Connection {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if self.is_trace_enabled() {
log::trace!(
"{:?} writing {} bytes:{:?}",
self,
buf.len(),
String::from_utf8_lossy(buf)
);
}
match self {
Tcp(ref mut t) => t.write(buf),
Tcp { ref mut inner, .. } => inner.write(buf),
#[cfg(feature = "tls")]
Tls(ref mut t) => t.write(buf),
Fd(f) => {
Tls { ref mut inner, .. } => inner.write(buf),
Fd { ref inner, .. } => {
use std::os::unix::io::{FromRawFd, IntoRawFd};
let mut f = unsafe { std::fs::File::from_raw_fd(*f) };
let mut f = unsafe { std::fs::File::from_raw_fd(*inner) };
let ret = f.write(buf);
let _ = f.into_raw_fd();
ret
@ -230,12 +334,12 @@ impl std::io::Write for Connection {
fn flush(&mut self) -> std::io::Result<()> {
match self {
Tcp(ref mut t) => t.flush(),
Tcp { ref mut inner, .. } => inner.flush(),
#[cfg(feature = "tls")]
Tls(ref mut t) => t.flush(),
Fd(f) => {
Tls { ref mut inner, .. } => inner.flush(),
Fd { ref inner, .. } => {
use std::os::unix::io::{FromRawFd, IntoRawFd};
let mut f = unsafe { std::fs::File::from_raw_fd(*f) };
let mut f = unsafe { std::fs::File::from_raw_fd(*inner) };
let ret = f.flush();
let _ = f.into_raw_fd();
ret
@ -249,10 +353,10 @@ impl std::io::Write for Connection {
impl std::os::unix::io::AsRawFd for Connection {
fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
match self {
Tcp(ref t) => t.as_raw_fd(),
Tcp { ref inner, .. } => inner.as_raw_fd(),
#[cfg(feature = "tls")]
Tls(ref t) => t.get_ref().as_raw_fd(),
Fd(f) => *f,
Tls { ref inner, .. } => inner.get_ref().as_raw_fd(),
Fd { ref inner, .. } => *inner,
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().as_raw_fd(),
}