From bf5f0b3568743ff3a869971127841a88795057dd Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Wed, 4 Dec 2019 17:14:11 +0100 Subject: [PATCH] Update to tokio 0.2 --- Cargo.toml | 9 ++--- src/dnscrypt_certs.rs | 11 ++++--- src/globals.rs | 4 +-- src/main.rs | 76 +++++++++++++++++++++++-------------------- src/metrics.rs | 23 +++++++------ src/resolver.rs | 9 ++--- 6 files changed, 68 insertions(+), 64 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 735ed3d..fc2428a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,8 @@ coarsetime = "0.1.12" daemonize-simple = "0.1.2" derivative = "1.0.3" dnsstamps = "0.1.3" -env_logger = { version="0.7.1", default-features = false, features = ["humantime"] } -futures-preview = { version = "=0.3.0-alpha.19", features = ["async-await"] } +env_logger = { version = "0.7.1", default-features = false, features = ["humantime"] } +futures = { version = "0.3", features = ["async-await"] } ipext = "0.1.0" jemallocator = "0.3.2" libsodium-sys-stable="1.18.2" @@ -35,12 +35,13 @@ serde = "1.0.103" serde_derive = "1.0.103" serde-big-array = "0.2.0" siphasher = "0.3.1" -tokio = "=0.2.0-alpha.6" +tokio = { version = "0.2", features = ["full"] } toml = "0.5.5" [dependencies.hyper] optional = true -version = "0.13.0-alpha.4" +git = "https://github.com/hyperium/hyper" +branch = "tokio-up" default_features = false [dependencies.prometheus] diff --git a/src/dnscrypt_certs.rs b/src/dnscrypt_certs.rs index 6275c4f..a98f8b0 100644 --- a/src/dnscrypt_certs.rs +++ b/src/dnscrypt_certs.rs @@ -172,7 +172,7 @@ impl DNSCryptEncryptionParamsUpdater { dnscrypt_encryption_params_set: new_params_set.iter().map(|x| (**x).clone()).collect(), }; let state_file = self.globals.state_file.to_path_buf(); - self.globals.runtime.spawn(async move { + self.globals.runtime_handle.spawn(async move { let _ = state.async_save(state_file).await; }); *self.globals.dnscrypt_encryption_params_set.write() = Arc::new(new_params_set); @@ -180,11 +180,12 @@ impl DNSCryptEncryptionParamsUpdater { } pub async fn run(self) { - let mut fut_interval = tokio::timer::Interval::new_interval( - std::time::Duration::from_secs(u64::from(DNSCRYPT_CERTS_RENEWAL)), - ); + let mut fut_interval = tokio::time::interval(std::time::Duration::from_secs(u64::from( + DNSCRYPT_CERTS_RENEWAL, + ))); let fut = async move { - while fut_interval.next().await.is_some() { + loop { + fut_interval.tick().await; self.update(); debug!("New cert issued"); } diff --git a/src/globals.rs b/src/globals.rs index 4d7eb36..16871fd 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -13,13 +13,13 @@ use std::path::PathBuf; use std::sync::atomic::AtomicU32; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; use tokio::sync::oneshot; #[derive(Clone, Derivative)] #[derivative(Debug)] pub struct Globals { - pub runtime: Arc, + pub runtime_handle: Handle, pub state_file: PathBuf, pub dnscrypt_encryption_params_set: Arc>>>>, pub provider_name: String, diff --git a/src/main.rs b/src/main.rs index 007b0be..3e18e85 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,7 +71,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::prelude::*; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; use tokio::sync::oneshot; #[derive(Debug)] @@ -238,21 +238,20 @@ async fn tls_proxy( SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?, }, }; - let mut ext_socket = - TcpStream::connect_std(std_socket, tls_upstream_addr, &Default::default()).await?; + let mut ext_socket = TcpStream::connect_std(std_socket, tls_upstream_addr).await?; let (mut erh, mut ewh) = ext_socket.split(); let (mut rh, mut wh) = client_connection.split(); ewh.write_all(&binlen).await?; - let fut_proxy_1 = rh.copy(&mut ewh); - let fut_proxy_2 = erh.copy(&mut wh); + let fut_proxy_1 = tokio::io::copy(&mut rh, &mut ewh); + let fut_proxy_2 = tokio::io::copy(&mut erh, &mut wh); match join!(fut_proxy_1, fut_proxy_2) { (Ok(_), Ok(_)) => Ok(()), _ => bail!("TLS proxy error"), } } -async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Result<(), Error> { - let runtime = globals.runtime.clone(); +async fn tcp_acceptor(globals: Arc, mut tcp_listener: TcpListener) -> Result<(), Error> { + let runtime_handle = globals.runtime_handle.clone(); let mut tcp_listener = tcp_listener.incoming(); let timeout = globals.tcp_timeout; let concurrent_connections = globals.tcp_concurrent_connections.clone(); @@ -301,8 +300,8 @@ async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Resul Ok(()) }; let fut_abort = rx; - let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout); - runtime.spawn(fut_all.map(move |_| { + let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort)); + runtime_handle.spawn(fut_all.map(move |_| { let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed); #[cfg(feature = "metrics")] varz.inflight_tcp_queries @@ -317,7 +316,7 @@ async fn udp_acceptor( globals: Arc, net_udp_socket: std::net::UdpSocket, ) -> Result<(), Error> { - let runtime = globals.runtime.clone(); + let runtime_handle = globals.runtime_handle.clone(); let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?; let timeout = globals.udp_timeout; let concurrent_connections = globals.udp_concurrent_connections.clone(); @@ -356,8 +355,8 @@ async fn udp_acceptor( let concurrent_connections = concurrent_connections.clone(); let fut = handle_client_query(globals, client_ctx, packet); let fut_abort = rx; - let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout); - runtime.spawn(fut_all.map(move |_| { + let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort)); + runtime_handle.spawn(fut_all.map(move |_| { let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed); #[cfg(feature = "metrics")] varz.inflight_udp_queries @@ -368,22 +367,27 @@ async fn udp_acceptor( async fn start( globals: Arc, - runtime: Arc, - listeners: Vec<(TcpListener, std::net::UdpSocket)>, + runtime_handle: Handle, + listeners: Vec<(std::net::TcpListener, std::net::UdpSocket)>, ) -> Result<(), Error> { for listener in listeners { - runtime.spawn(tcp_acceptor(globals.clone(), listener.0).map(|_| {})); - runtime.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {})); + let tcp_listener_str = format!("{:?}", listener.0); + let tokio_tcp_listener = match TcpListener::from_std(listener.0) { + Ok(tcp_listener) => tcp_listener, + Err(e) => bail!("{}/TCP: {}", tcp_listener_str, e), + }; + runtime_handle.spawn(tcp_acceptor(globals.clone(), tokio_tcp_listener).map(|_| {})); + runtime_handle.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {})); } Ok(()) } fn bind_listeners( listen_addrs: &[SocketAddr], -) -> Result, Error> { +) -> Result, Error> { let mut sockets = Vec::with_capacity(listen_addrs.len()); for listen_addr in listen_addrs { - let std_socket = match listen_addr { + let tcp_listener = match listen_addr { SocketAddr::V4(_) => net2::TcpBuilder::new_v4()? .reuse_address(true)? .bind(&listen_addr)? @@ -394,10 +398,6 @@ fn bind_listeners( .bind(&listen_addr)? .listen(1024)?, }; - let tcp_listener = match TcpListener::from_std(std_socket, &Default::default()) { - Ok(tcp_listener) => tcp_listener, - Err(e) => bail!("{}/TCP: {}", listen_addr, e), - }; let std_socket = match listen_addr { SocketAddr::V4(_) => net2::UdpBuilder::new_v4()? .reuse_address(true)? @@ -495,8 +495,10 @@ fn main() -> Result<(), Error> { let external_addr = config.external_addr.map(|addr| SocketAddr::new(addr, 0)); let mut runtime_builder = tokio::runtime::Builder::new(); - runtime_builder.name_prefix("encrypted-dns-"); - let runtime = Arc::new(runtime_builder.build()?); + runtime_builder.enable_all(); + runtime_builder.threaded_scheduler(); + runtime_builder.thread_name("encrypted-dns-"); + let mut runtime = runtime_builder.build()?; let listen_addrs: Vec<_> = config.listen_addrs.iter().map(|x| x.local).collect(); let listeners = bind_listeners(&listen_addrs) @@ -627,9 +629,9 @@ fn main() -> Result<(), Error> { anonymized_dns.blacklisted_ips, ), }; - + let runtime_handle = runtime.handle(); let globals = Arc::new(Globals { - runtime: runtime.clone(), + runtime_handle: runtime_handle.clone(), state_file: state_file.to_path_buf(), dnscrypt_encryption_params_set: Arc::new(RwLock::new(Arc::new( dnscrypt_encryption_params_set, @@ -671,18 +673,22 @@ fn main() -> Result<(), Error> { #[cfg(feature = "metrics")] { if let Some(metrics_config) = config.metrics { - runtime.spawn( - metrics::prometheus_service(globals.varz.clone(), metrics_config, runtime.clone()) - .map_err(|e| { - error!("Unable to start the metrics service: [{}]", e); - std::process::exit(1); - }) - .map(|_| ()), + runtime_handle.spawn( + metrics::prometheus_service( + globals.varz.clone(), + metrics_config, + runtime_handle.clone(), + ) + .map_err(|e| { + error!("Unable to start the metrics service: [{}]", e); + std::process::exit(1); + }) + .map(|_| ()), ); } } - runtime.spawn( - start(globals, runtime.clone(), listeners) + runtime_handle.spawn( + start(globals, runtime_handle.clone(), listeners) .map_err(|e| { error!("Unable to start the service: [{}]", e); std::process::exit(1); diff --git a/src/metrics.rs b/src/metrics.rs index bcad8bc..5df82e7 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -14,8 +14,7 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use tokio::net::TcpListener; -use tokio::prelude::*; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; const METRICS_CONNECTION_TIMEOUT_SECS: u64 = 10; const METRICS_MAX_CONCURRENT_CONNECTIONS: u32 = 2; @@ -50,7 +49,7 @@ async fn handle_client_connection( pub async fn prometheus_service( varz: Varz, metrics_config: MetricsConfig, - runtime: Arc, + runtime_handle: Handle, ) -> Result<(), Error> { let path = Arc::new(metrics_config.path); let std_socket = match metrics_config.listen_addr { @@ -64,7 +63,7 @@ pub async fn prometheus_service( .bind(&metrics_config.listen_addr)? .listen(1024)?, }; - let mut stream = TcpListener::from_std(std_socket, &Default::default())?; + let mut stream = TcpListener::from_std(std_socket)?; let concurrent_connections = Arc::new(AtomicU32::new(0)); loop { let (client, _client_addr) = stream.accept().await?; @@ -80,14 +79,14 @@ pub async fn prometheus_service( service_fn(move |req| handle_client_connection(req, varz.clone(), path.clone())); let connection = Http::new().serve_connection(client, service); let concurrent_connections = concurrent_connections.clone(); - runtime.spawn( - connection - .timeout(std::time::Duration::from_secs( - METRICS_CONNECTION_TIMEOUT_SECS, - )) - .map(move |_| { - concurrent_connections.fetch_sub(1, Ordering::Relaxed); - }), + runtime_handle.spawn( + tokio::time::timeout( + std::time::Duration::from_secs(METRICS_CONNECTION_TIMEOUT_SECS), + connection, + ) + .map(move |_| { + concurrent_connections.fetch_sub(1, Ordering::Relaxed); + }), ); } Ok(()) diff --git a/src/resolver.rs b/src/resolver.rs index 35fa122..ee230d9 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -29,7 +29,7 @@ pub async fn resolve_udp( ))?, }, }; - let mut ext_socket = UdpSocket::from_std(std_socket, &Default::default())?; + let mut ext_socket = UdpSocket::from_std(std_socket)?; ext_socket.connect(&globals.upstream_addr).await?; dns::set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?; let mut response; @@ -38,9 +38,7 @@ pub async fn resolve_udp( ext_socket.send(&packet).await?; response = vec![0u8; DNS_MAX_PACKET_SIZE]; dns::set_rcode_servfail(&mut response); - let fut = ext_socket - .recv_from(&mut response[..]) - .timeout(timeout_if_cached); + let fut = tokio::time::timeout(timeout_if_cached, ext_socket.recv_from(&mut response[..])); match fut.await { Ok(Ok((response_len, response_addr))) => { response.truncate(response_len); @@ -78,8 +76,7 @@ pub async fn resolve_tcp( SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?, }, }; - let mut ext_socket = - TcpStream::connect_std(std_socket, &globals.upstream_addr, &Default::default()).await?; + let mut ext_socket = TcpStream::connect_std(std_socket, &globals.upstream_addr).await?; ext_socket.set_nodelay(true)?; let mut binlen = [0u8, 0]; BigEndian::write_u16(&mut binlen[..], packet.len() as u16);