2
0
mirror of https://github.com/jedisct1/encrypted-dns-server synced 2024-11-12 13:10:44 +00:00

Update to tokio 0.2

This commit is contained in:
Frank Denis 2019-12-04 17:14:11 +01:00
parent f96ba4d62c
commit bf5f0b3568
6 changed files with 68 additions and 64 deletions

View File

@ -20,8 +20,8 @@ coarsetime = "0.1.12"
daemonize-simple = "0.1.2"
derivative = "1.0.3"
dnsstamps = "0.1.3"
env_logger = { version="0.7.1", default-features = false, features = ["humantime"] }
futures-preview = { version = "=0.3.0-alpha.19", features = ["async-await"] }
env_logger = { version = "0.7.1", default-features = false, features = ["humantime"] }
futures = { version = "0.3", features = ["async-await"] }
ipext = "0.1.0"
jemallocator = "0.3.2"
libsodium-sys-stable="1.18.2"
@ -35,12 +35,13 @@ serde = "1.0.103"
serde_derive = "1.0.103"
serde-big-array = "0.2.0"
siphasher = "0.3.1"
tokio = "=0.2.0-alpha.6"
tokio = { version = "0.2", features = ["full"] }
toml = "0.5.5"
[dependencies.hyper]
optional = true
version = "0.13.0-alpha.4"
git = "https://github.com/hyperium/hyper"
branch = "tokio-up"
default_features = false
[dependencies.prometheus]

View File

@ -172,7 +172,7 @@ impl DNSCryptEncryptionParamsUpdater {
dnscrypt_encryption_params_set: new_params_set.iter().map(|x| (**x).clone()).collect(),
};
let state_file = self.globals.state_file.to_path_buf();
self.globals.runtime.spawn(async move {
self.globals.runtime_handle.spawn(async move {
let _ = state.async_save(state_file).await;
});
*self.globals.dnscrypt_encryption_params_set.write() = Arc::new(new_params_set);
@ -180,11 +180,12 @@ impl DNSCryptEncryptionParamsUpdater {
}
pub async fn run(self) {
let mut fut_interval = tokio::timer::Interval::new_interval(
std::time::Duration::from_secs(u64::from(DNSCRYPT_CERTS_RENEWAL)),
);
let mut fut_interval = tokio::time::interval(std::time::Duration::from_secs(u64::from(
DNSCRYPT_CERTS_RENEWAL,
)));
let fut = async move {
while fut_interval.next().await.is_some() {
loop {
fut_interval.tick().await;
self.update();
debug!("New cert issued");
}

View File

@ -13,13 +13,13 @@ use std::path::PathBuf;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct Globals {
pub runtime: Arc<Runtime>,
pub runtime_handle: Handle,
pub state_file: PathBuf,
pub dnscrypt_encryption_params_set: Arc<RwLock<Arc<Vec<Arc<DNSCryptEncryptionParams>>>>>,
pub provider_name: String,

View File

@ -71,7 +71,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::prelude::*;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
#[derive(Debug)]
@ -238,21 +238,20 @@ async fn tls_proxy(
SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?,
},
};
let mut ext_socket =
TcpStream::connect_std(std_socket, tls_upstream_addr, &Default::default()).await?;
let mut ext_socket = TcpStream::connect_std(std_socket, tls_upstream_addr).await?;
let (mut erh, mut ewh) = ext_socket.split();
let (mut rh, mut wh) = client_connection.split();
ewh.write_all(&binlen).await?;
let fut_proxy_1 = rh.copy(&mut ewh);
let fut_proxy_2 = erh.copy(&mut wh);
let fut_proxy_1 = tokio::io::copy(&mut rh, &mut ewh);
let fut_proxy_2 = tokio::io::copy(&mut erh, &mut wh);
match join!(fut_proxy_1, fut_proxy_2) {
(Ok(_), Ok(_)) => Ok(()),
_ => bail!("TLS proxy error"),
}
}
async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Result<(), Error> {
let runtime = globals.runtime.clone();
async fn tcp_acceptor(globals: Arc<Globals>, mut tcp_listener: TcpListener) -> Result<(), Error> {
let runtime_handle = globals.runtime_handle.clone();
let mut tcp_listener = tcp_listener.incoming();
let timeout = globals.tcp_timeout;
let concurrent_connections = globals.tcp_concurrent_connections.clone();
@ -301,8 +300,8 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
Ok(())
};
let fut_abort = rx;
let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout);
runtime.spawn(fut_all.map(move |_| {
let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort));
runtime_handle.spawn(fut_all.map(move |_| {
let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed);
#[cfg(feature = "metrics")]
varz.inflight_tcp_queries
@ -317,7 +316,7 @@ async fn udp_acceptor(
globals: Arc<Globals>,
net_udp_socket: std::net::UdpSocket,
) -> Result<(), Error> {
let runtime = globals.runtime.clone();
let runtime_handle = globals.runtime_handle.clone();
let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?;
let timeout = globals.udp_timeout;
let concurrent_connections = globals.udp_concurrent_connections.clone();
@ -356,8 +355,8 @@ async fn udp_acceptor(
let concurrent_connections = concurrent_connections.clone();
let fut = handle_client_query(globals, client_ctx, packet);
let fut_abort = rx;
let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout);
runtime.spawn(fut_all.map(move |_| {
let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort));
runtime_handle.spawn(fut_all.map(move |_| {
let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed);
#[cfg(feature = "metrics")]
varz.inflight_udp_queries
@ -368,22 +367,27 @@ async fn udp_acceptor(
async fn start(
globals: Arc<Globals>,
runtime: Arc<Runtime>,
listeners: Vec<(TcpListener, std::net::UdpSocket)>,
runtime_handle: Handle,
listeners: Vec<(std::net::TcpListener, std::net::UdpSocket)>,
) -> Result<(), Error> {
for listener in listeners {
runtime.spawn(tcp_acceptor(globals.clone(), listener.0).map(|_| {}));
runtime.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {}));
let tcp_listener_str = format!("{:?}", listener.0);
let tokio_tcp_listener = match TcpListener::from_std(listener.0) {
Ok(tcp_listener) => tcp_listener,
Err(e) => bail!("{}/TCP: {}", tcp_listener_str, e),
};
runtime_handle.spawn(tcp_acceptor(globals.clone(), tokio_tcp_listener).map(|_| {}));
runtime_handle.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {}));
}
Ok(())
}
fn bind_listeners(
listen_addrs: &[SocketAddr],
) -> Result<Vec<(TcpListener, std::net::UdpSocket)>, Error> {
) -> Result<Vec<(std::net::TcpListener, std::net::UdpSocket)>, Error> {
let mut sockets = Vec::with_capacity(listen_addrs.len());
for listen_addr in listen_addrs {
let std_socket = match listen_addr {
let tcp_listener = match listen_addr {
SocketAddr::V4(_) => net2::TcpBuilder::new_v4()?
.reuse_address(true)?
.bind(&listen_addr)?
@ -394,10 +398,6 @@ fn bind_listeners(
.bind(&listen_addr)?
.listen(1024)?,
};
let tcp_listener = match TcpListener::from_std(std_socket, &Default::default()) {
Ok(tcp_listener) => tcp_listener,
Err(e) => bail!("{}/TCP: {}", listen_addr, e),
};
let std_socket = match listen_addr {
SocketAddr::V4(_) => net2::UdpBuilder::new_v4()?
.reuse_address(true)?
@ -495,8 +495,10 @@ fn main() -> Result<(), Error> {
let external_addr = config.external_addr.map(|addr| SocketAddr::new(addr, 0));
let mut runtime_builder = tokio::runtime::Builder::new();
runtime_builder.name_prefix("encrypted-dns-");
let runtime = Arc::new(runtime_builder.build()?);
runtime_builder.enable_all();
runtime_builder.threaded_scheduler();
runtime_builder.thread_name("encrypted-dns-");
let mut runtime = runtime_builder.build()?;
let listen_addrs: Vec<_> = config.listen_addrs.iter().map(|x| x.local).collect();
let listeners = bind_listeners(&listen_addrs)
@ -627,9 +629,9 @@ fn main() -> Result<(), Error> {
anonymized_dns.blacklisted_ips,
),
};
let runtime_handle = runtime.handle();
let globals = Arc::new(Globals {
runtime: runtime.clone(),
runtime_handle: runtime_handle.clone(),
state_file: state_file.to_path_buf(),
dnscrypt_encryption_params_set: Arc::new(RwLock::new(Arc::new(
dnscrypt_encryption_params_set,
@ -671,18 +673,22 @@ fn main() -> Result<(), Error> {
#[cfg(feature = "metrics")]
{
if let Some(metrics_config) = config.metrics {
runtime.spawn(
metrics::prometheus_service(globals.varz.clone(), metrics_config, runtime.clone())
.map_err(|e| {
error!("Unable to start the metrics service: [{}]", e);
std::process::exit(1);
})
.map(|_| ()),
runtime_handle.spawn(
metrics::prometheus_service(
globals.varz.clone(),
metrics_config,
runtime_handle.clone(),
)
.map_err(|e| {
error!("Unable to start the metrics service: [{}]", e);
std::process::exit(1);
})
.map(|_| ()),
);
}
}
runtime.spawn(
start(globals, runtime.clone(), listeners)
runtime_handle.spawn(
start(globals, runtime_handle.clone(), listeners)
.map_err(|e| {
error!("Unable to start the service: [{}]", e);
std::process::exit(1);

View File

@ -14,8 +14,7 @@ use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::prelude::*;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;
const METRICS_CONNECTION_TIMEOUT_SECS: u64 = 10;
const METRICS_MAX_CONCURRENT_CONNECTIONS: u32 = 2;
@ -50,7 +49,7 @@ async fn handle_client_connection(
pub async fn prometheus_service(
varz: Varz,
metrics_config: MetricsConfig,
runtime: Arc<Runtime>,
runtime_handle: Handle,
) -> Result<(), Error> {
let path = Arc::new(metrics_config.path);
let std_socket = match metrics_config.listen_addr {
@ -64,7 +63,7 @@ pub async fn prometheus_service(
.bind(&metrics_config.listen_addr)?
.listen(1024)?,
};
let mut stream = TcpListener::from_std(std_socket, &Default::default())?;
let mut stream = TcpListener::from_std(std_socket)?;
let concurrent_connections = Arc::new(AtomicU32::new(0));
loop {
let (client, _client_addr) = stream.accept().await?;
@ -80,14 +79,14 @@ pub async fn prometheus_service(
service_fn(move |req| handle_client_connection(req, varz.clone(), path.clone()));
let connection = Http::new().serve_connection(client, service);
let concurrent_connections = concurrent_connections.clone();
runtime.spawn(
connection
.timeout(std::time::Duration::from_secs(
METRICS_CONNECTION_TIMEOUT_SECS,
))
.map(move |_| {
concurrent_connections.fetch_sub(1, Ordering::Relaxed);
}),
runtime_handle.spawn(
tokio::time::timeout(
std::time::Duration::from_secs(METRICS_CONNECTION_TIMEOUT_SECS),
connection,
)
.map(move |_| {
concurrent_connections.fetch_sub(1, Ordering::Relaxed);
}),
);
}
Ok(())

View File

@ -29,7 +29,7 @@ pub async fn resolve_udp(
))?,
},
};
let mut ext_socket = UdpSocket::from_std(std_socket, &Default::default())?;
let mut ext_socket = UdpSocket::from_std(std_socket)?;
ext_socket.connect(&globals.upstream_addr).await?;
dns::set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?;
let mut response;
@ -38,9 +38,7 @@ pub async fn resolve_udp(
ext_socket.send(&packet).await?;
response = vec![0u8; DNS_MAX_PACKET_SIZE];
dns::set_rcode_servfail(&mut response);
let fut = ext_socket
.recv_from(&mut response[..])
.timeout(timeout_if_cached);
let fut = tokio::time::timeout(timeout_if_cached, ext_socket.recv_from(&mut response[..]));
match fut.await {
Ok(Ok((response_len, response_addr))) => {
response.truncate(response_len);
@ -78,8 +76,7 @@ pub async fn resolve_tcp(
SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?,
},
};
let mut ext_socket =
TcpStream::connect_std(std_socket, &globals.upstream_addr, &Default::default()).await?;
let mut ext_socket = TcpStream::connect_std(std_socket, &globals.upstream_addr).await?;
ext_socket.set_nodelay(true)?;
let mut binlen = [0u8, 0];
BigEndian::write_u16(&mut binlen[..], packet.len() as u16);