From 4874e30f3ce9b186ac7cd427cba4a8542bd5048e Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Sat, 22 Jul 2023 16:25:54 +0300 Subject: [PATCH] melib: add smtp-trace feature If it's enabled, every read/write in an SMTP transaction will be logged on TRACE level. --- melib/Cargo.toml | 1 + melib/src/backends/imap/connection.rs | 10 +- melib/src/backends/nntp/connection.rs | 6 +- melib/src/smtp.rs | 36 ++++-- melib/src/utils/connections.rs | 174 ++++++++++++++++++++------ 5 files changed, 175 insertions(+), 52 deletions(-) diff --git a/melib/Cargo.toml b/melib/Cargo.toml index 2ccac383..cb83e945 100644 --- a/melib/Cargo.toml +++ b/melib/Cargo.toml @@ -83,6 +83,7 @@ maildir_backend = ["notify"] mbox_backend = ["notify"] notmuch_backend = [] smtp = ["tls", "base64"] +smtp-trace = ["smtp"] sqlite3 = ["rusqlite", ] tls = ["native-tls"] unicode_algorithms = ["unicode-segmentation"] diff --git a/melib/src/backends/imap/connection.rs b/melib/src/backends/imap/connection.rs index 21275848..193b5ccc 100644 --- a/melib/src/backends/imap/connection.rs +++ b/melib/src/backends/imap/connection.rs @@ -179,7 +179,7 @@ impl ImapStream { let addr = lookup_ipv4(path, server_conf.server_port)?; - let mut socket = AsyncWrapper::new(Connection::Tcp( + let mut socket = AsyncWrapper::new(Connection::new_tcp( if let Some(timeout) = server_conf.timeout { TcpStream::connect_timeout(&addr, timeout)? } else { @@ -271,14 +271,14 @@ impl ImapStream { } } } - AsyncWrapper::new(Connection::Tls(conn_result.chain_err_summary(|| { - format!("Could not initiate TLS negotiation to {}.", path) - })?)) + AsyncWrapper::new(Connection::new_tls(conn_result.chain_err_summary( + || format!("Could not initiate TLS negotiation to {}.", path), + )?)) .chain_err_summary(|| format!("Could not initiate TLS negotiation to {}.", path))? } } else { let addr = lookup_ipv4(path, server_conf.server_port)?; - AsyncWrapper::new(Connection::Tcp( + AsyncWrapper::new(Connection::new_tcp( if let Some(timeout) = server_conf.timeout { TcpStream::connect_timeout(&addr, timeout)? } else { diff --git a/melib/src/backends/nntp/connection.rs b/melib/src/backends/nntp/connection.rs index 4b773149..872f6dc3 100644 --- a/melib/src/backends/nntp/connection.rs +++ b/melib/src/backends/nntp/connection.rs @@ -79,7 +79,7 @@ impl NntpStream { let stream = { let addr = lookup_ipv4(path, server_conf.server_port)?; - AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout( + AsyncWrapper::new(Connection::new_tcp(TcpStream::connect_timeout( &addr, std::time::Duration::new(16, 0), )?))? @@ -170,8 +170,8 @@ impl NntpStream { } } } - ret.stream = - AsyncWrapper::new(Connection::Tls(conn_result?)).chain_err_summary(|| { + ret.stream = AsyncWrapper::new(Connection::new_tls(conn_result?)) + .chain_err_summary(|| { format!("Could not initiate TLS negotiation to {}.", path) })?; } diff --git a/melib/src/smtp.rs b/melib/src/smtp.rs index c933f08e..61462e6a 100644 --- a/melib/src/smtp.rs +++ b/melib/src/smtp.rs @@ -262,10 +262,16 @@ impl SmtpConnection { let connector = connector.build()?; let addr = lookup_ipv4(path, server_conf.port)?; - let mut socket = AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), - )?))?; + let mut socket = { + let conn = Connection::new_tcp(TcpStream::connect_timeout( + &addr, + std::time::Duration::new(4, 0), + )?); + #[cfg(feature = "smtp-trace")] + let conn = conn.trace(true); + + AsyncWrapper::new(conn)? + }; let pre_ehlo_extensions_reply = read_lines( &mut socket, &mut res, @@ -315,6 +321,8 @@ impl SmtpConnection { let mut ret = { let socket = socket.into_inner()?; + #[cfg(feature = "smtp-trace")] + let socket = socket.trace(false); let _path = path.clone(); socket.set_nonblocking(false)?; @@ -340,17 +348,27 @@ impl SmtpConnection { } } */ - AsyncWrapper::new(Connection::Tls(conn))? + AsyncWrapper::new({ + let conn = Connection::new_tls(conn); + #[cfg(feature = "smtp-trace")] + let conn = conn.trace(true); + conn + })? }; ret.write_all(b"EHLO meli.delivery\r\n").await?; ret } SmtpSecurity::None => { let addr = lookup_ipv4(path, server_conf.port)?; - let mut ret = AsyncWrapper::new(Connection::Tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), - )?))?; + let mut ret = AsyncWrapper::new({ + let conn = Connection::new_tcp(TcpStream::connect_timeout( + &addr, + std::time::Duration::new(4, 0), + )?); + #[cfg(feature = "smtp-trace")] + let conn = conn.trace(true); + conn + })?; res.clear(); let reply = read_lines( &mut ret, diff --git a/melib/src/utils/connections.rs b/melib/src/utils/connections.rs index 77d6b9da..6313daab 100644 --- a/melib/src/utils/connections.rs +++ b/melib/src/utils/connections.rs @@ -40,13 +40,23 @@ use libc::{self, c_int, c_void}; #[derive(Debug)] pub enum Connection { - Tcp(std::net::TcpStream), - Fd(std::os::unix::io::RawFd), + Tcp { + inner: std::net::TcpStream, + trace: bool, + }, + Fd { + inner: std::os::unix::io::RawFd, + trace: bool, + }, #[cfg(feature = "tls")] - Tls(native_tls::TlsStream), + Tls { + inner: native_tls::TlsStream, + trace: bool, + }, #[cfg(feature = "deflate_compression")] Deflate { inner: DeflateEncoder>>, + trace: bool, }, } @@ -65,26 +75,77 @@ macro_rules! syscall { } impl Connection { - pub const IO_BUF_SIZE: usize = 64 * 1024; #[cfg(feature = "deflate_compression")] - pub fn deflate(self) -> Self { + pub const IO_BUF_SIZE: usize = 64 * 1024; + + #[cfg(feature = "deflate_compression")] + pub fn deflate(mut self) -> Self { + let trace = self.is_trace_enabled(); + self.set_trace(false); Self::Deflate { inner: DeflateEncoder::new( DeflateDecoder::new_with_buf(Box::new(self), vec![0; Self::IO_BUF_SIZE]), Compression::default(), ), + trace, + } + } + + #[cfg(feature = "tls")] + pub fn new_tls(mut inner: native_tls::TlsStream) -> Self { + let trace = inner.get_ref().is_trace_enabled(); + if trace { + inner.get_mut().set_trace(false); + } + Self::Tls { inner, trace } + } + + pub fn new_tcp(inner: std::net::TcpStream) -> Self { + Self::Tcp { + inner, + trace: false, + } + } + + pub fn trace(mut self, val: bool) -> Self { + match self { + Tcp { ref mut trace, .. } => *trace = val, + #[cfg(feature = "tls")] + Tls { ref mut trace, .. } => *trace = val, + Fd { ref mut trace, .. } => { + *trace = val; + } + #[cfg(feature = "deflate_compression")] + Deflate { ref mut trace, .. } => *trace = val, + } + self + } + + pub fn set_trace(&mut self, val: bool) { + match self { + Tcp { ref mut trace, .. } => *trace = val, + #[cfg(feature = "tls")] + Tls { ref mut trace, .. } => *trace = val, + Fd { ref mut trace, .. } => { + *trace = val; + } + #[cfg(feature = "deflate_compression")] + Deflate { ref mut trace, .. } => *trace = val, } } pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> { + if self.is_trace_enabled() { + log::trace!("{:?} set_nonblocking({:?})", self, nonblocking); + } match self { - Tcp(ref t) => t.set_nonblocking(nonblocking), + Tcp { ref inner, .. } => inner.set_nonblocking(nonblocking), #[cfg(feature = "tls")] - Tls(ref t) => t.get_ref().set_nonblocking(nonblocking), - Fd(fd) => { + Tls { ref inner, .. } => inner.get_ref().set_nonblocking(nonblocking), + Fd { inner, .. } => { // [ref:VERIFY] nix::fcntl::fcntl( - *fd, + *inner, nix::fcntl::FcntlArg::F_SETFL(if nonblocking { nix::fcntl::OFlag::O_NONBLOCK } else { @@ -100,29 +161,38 @@ impl Connection { } pub fn set_read_timeout(&self, dur: Option) -> std::io::Result<()> { + if self.is_trace_enabled() { + log::trace!("{:?} set_read_timeout({:?})", self, dur); + } match self { - Tcp(ref t) => t.set_read_timeout(dur), + Tcp { ref inner, .. } => inner.set_read_timeout(dur), #[cfg(feature = "tls")] - Tls(ref t) => t.get_ref().set_read_timeout(dur), - Fd(_) => Ok(()), + Tls { ref inner, .. } => inner.get_ref().set_read_timeout(dur), + Fd { .. } => Ok(()), #[cfg(feature = "deflate_compression")] Deflate { ref inner, .. } => inner.get_ref().get_ref().set_read_timeout(dur), } } pub fn set_write_timeout(&self, dur: Option) -> std::io::Result<()> { + if self.is_trace_enabled() { + log::trace!("{:?} set_write_timeout({:?})", self, dur); + } match self { - Tcp(ref t) => t.set_write_timeout(dur), + Tcp { ref inner, .. } => inner.set_write_timeout(dur), #[cfg(feature = "tls")] - Tls(ref t) => t.get_ref().set_write_timeout(dur), - Fd(_) => Ok(()), + Tls { ref inner, .. } => inner.get_ref().set_write_timeout(dur), + Fd { .. } => Ok(()), #[cfg(feature = "deflate_compression")] Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur), } } pub fn keepalive(&self) -> std::io::Result> { - if let Fd(_) = self { + if self.is_trace_enabled() { + log::trace!("{:?} keepalive()", self); + } + if matches!(self, Fd { .. }) { return Ok(None); } unsafe { @@ -136,7 +206,10 @@ impl Connection { } pub fn set_keepalive(&self, keepalive: Option) -> std::io::Result<()> { - if let Fd(_) = self { + if self.is_trace_enabled() { + log::trace!("{:?} set_keepalive({:?})", self, keepalive); + } + if matches!(self, Fd { .. }) { return Ok(()); } unsafe { @@ -181,44 +254,75 @@ impl Connection { assert_eq!(len as usize, std::mem::size_of::()); Ok(slot) } + + fn is_trace_enabled(&self) -> bool { + match self { + Fd { trace, .. } | Tcp { trace, .. } => *trace, + #[cfg(feature = "tls")] + Tls { trace, .. } => *trace, + #[cfg(feature = "deflate_compression")] + Deflate { trace, .. } => *trace, + } + } } impl Drop for Connection { fn drop(&mut self) { - if let Fd(fd) = self { - let _ = nix::unistd::close(*fd); + if let Fd { ref inner, .. } = self { + let _ = nix::unistd::close(*inner); } } } impl std::io::Read for Connection { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - match self { - Tcp(ref mut t) => t.read(buf), + let res = match self { + Tcp { ref mut inner, .. } => inner.read(buf), #[cfg(feature = "tls")] - Tls(ref mut t) => t.read(buf), - Fd(f) => { + Tls { ref mut inner, .. } => inner.read(buf), + Fd { ref inner, .. } => { use std::os::unix::io::{FromRawFd, IntoRawFd}; - let mut f = unsafe { std::fs::File::from_raw_fd(*f) }; + let mut f = unsafe { std::fs::File::from_raw_fd(*inner) }; let ret = f.read(buf); let _ = f.into_raw_fd(); ret } #[cfg(feature = "deflate_compression")] Deflate { ref mut inner, .. } => inner.read(buf), + }; + if self.is_trace_enabled() { + if let Ok(len) = &res { + log::trace!( + "{:?} read {:?} bytes:{:?}", + self, + len, + String::from_utf8_lossy(&buf[..*len]) + ); + } else { + log::trace!("{:?} could not read {:?}", self, &res); + } } + res } } impl std::io::Write for Connection { fn write(&mut self, buf: &[u8]) -> std::io::Result { + if self.is_trace_enabled() { + log::trace!( + "{:?} writing {} bytes:{:?}", + self, + buf.len(), + String::from_utf8_lossy(buf) + ); + } match self { - Tcp(ref mut t) => t.write(buf), + Tcp { ref mut inner, .. } => inner.write(buf), #[cfg(feature = "tls")] - Tls(ref mut t) => t.write(buf), - Fd(f) => { + Tls { ref mut inner, .. } => inner.write(buf), + Fd { ref inner, .. } => { use std::os::unix::io::{FromRawFd, IntoRawFd}; - let mut f = unsafe { std::fs::File::from_raw_fd(*f) }; + let mut f = unsafe { std::fs::File::from_raw_fd(*inner) }; let ret = f.write(buf); let _ = f.into_raw_fd(); ret @@ -230,12 +334,12 @@ impl std::io::Write for Connection { fn flush(&mut self) -> std::io::Result<()> { match self { - Tcp(ref mut t) => t.flush(), + Tcp { ref mut inner, .. } => inner.flush(), #[cfg(feature = "tls")] - Tls(ref mut t) => t.flush(), - Fd(f) => { + Tls { ref mut inner, .. } => inner.flush(), + Fd { ref inner, .. } => { use std::os::unix::io::{FromRawFd, IntoRawFd}; - let mut f = unsafe { std::fs::File::from_raw_fd(*f) }; + let mut f = unsafe { std::fs::File::from_raw_fd(*inner) }; let ret = f.flush(); let _ = f.into_raw_fd(); ret @@ -249,10 +353,10 @@ impl std::io::Write for Connection { impl std::os::unix::io::AsRawFd for Connection { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { match self { - Tcp(ref t) => t.as_raw_fd(), + Tcp { ref inner, .. } => inner.as_raw_fd(), #[cfg(feature = "tls")] - Tls(ref t) => t.get_ref().as_raw_fd(), - Fd(f) => *f, + Tls { ref inner, .. } => inner.get_ref().as_raw_fd(), + Fd { ref inner, .. } => *inner, #[cfg(feature = "deflate_compression")] Deflate { ref inner, .. } => inner.get_ref().get_ref().as_raw_fd(), }