mirror of
https://github.com/jedisct1/encrypted-dns-server
synced 2024-11-12 13:10:44 +00:00
Update to tokio 0.2
This commit is contained in:
parent
f96ba4d62c
commit
bf5f0b3568
@ -20,8 +20,8 @@ coarsetime = "0.1.12"
|
||||
daemonize-simple = "0.1.2"
|
||||
derivative = "1.0.3"
|
||||
dnsstamps = "0.1.3"
|
||||
env_logger = { version="0.7.1", default-features = false, features = ["humantime"] }
|
||||
futures-preview = { version = "=0.3.0-alpha.19", features = ["async-await"] }
|
||||
env_logger = { version = "0.7.1", default-features = false, features = ["humantime"] }
|
||||
futures = { version = "0.3", features = ["async-await"] }
|
||||
ipext = "0.1.0"
|
||||
jemallocator = "0.3.2"
|
||||
libsodium-sys-stable="1.18.2"
|
||||
@ -35,12 +35,13 @@ serde = "1.0.103"
|
||||
serde_derive = "1.0.103"
|
||||
serde-big-array = "0.2.0"
|
||||
siphasher = "0.3.1"
|
||||
tokio = "=0.2.0-alpha.6"
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
toml = "0.5.5"
|
||||
|
||||
[dependencies.hyper]
|
||||
optional = true
|
||||
version = "0.13.0-alpha.4"
|
||||
git = "https://github.com/hyperium/hyper"
|
||||
branch = "tokio-up"
|
||||
default_features = false
|
||||
|
||||
[dependencies.prometheus]
|
||||
|
@ -172,7 +172,7 @@ impl DNSCryptEncryptionParamsUpdater {
|
||||
dnscrypt_encryption_params_set: new_params_set.iter().map(|x| (**x).clone()).collect(),
|
||||
};
|
||||
let state_file = self.globals.state_file.to_path_buf();
|
||||
self.globals.runtime.spawn(async move {
|
||||
self.globals.runtime_handle.spawn(async move {
|
||||
let _ = state.async_save(state_file).await;
|
||||
});
|
||||
*self.globals.dnscrypt_encryption_params_set.write() = Arc::new(new_params_set);
|
||||
@ -180,11 +180,12 @@ impl DNSCryptEncryptionParamsUpdater {
|
||||
}
|
||||
|
||||
pub async fn run(self) {
|
||||
let mut fut_interval = tokio::timer::Interval::new_interval(
|
||||
std::time::Duration::from_secs(u64::from(DNSCRYPT_CERTS_RENEWAL)),
|
||||
);
|
||||
let mut fut_interval = tokio::time::interval(std::time::Duration::from_secs(u64::from(
|
||||
DNSCRYPT_CERTS_RENEWAL,
|
||||
)));
|
||||
let fut = async move {
|
||||
while fut_interval.next().await.is_some() {
|
||||
loop {
|
||||
fut_interval.tick().await;
|
||||
self.update();
|
||||
debug!("New cert issued");
|
||||
}
|
||||
|
@ -13,13 +13,13 @@ use std::path::PathBuf;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Clone, Derivative)]
|
||||
#[derivative(Debug)]
|
||||
pub struct Globals {
|
||||
pub runtime: Arc<Runtime>,
|
||||
pub runtime_handle: Handle,
|
||||
pub state_file: PathBuf,
|
||||
pub dnscrypt_encryption_params_set: Arc<RwLock<Arc<Vec<Arc<DNSCryptEncryptionParams>>>>>,
|
||||
pub provider_name: String,
|
||||
|
76
src/main.rs
76
src/main.rs
@ -71,7 +71,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use tokio::prelude::*;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -238,21 +238,20 @@ async fn tls_proxy(
|
||||
SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?,
|
||||
},
|
||||
};
|
||||
let mut ext_socket =
|
||||
TcpStream::connect_std(std_socket, tls_upstream_addr, &Default::default()).await?;
|
||||
let mut ext_socket = TcpStream::connect_std(std_socket, tls_upstream_addr).await?;
|
||||
let (mut erh, mut ewh) = ext_socket.split();
|
||||
let (mut rh, mut wh) = client_connection.split();
|
||||
ewh.write_all(&binlen).await?;
|
||||
let fut_proxy_1 = rh.copy(&mut ewh);
|
||||
let fut_proxy_2 = erh.copy(&mut wh);
|
||||
let fut_proxy_1 = tokio::io::copy(&mut rh, &mut ewh);
|
||||
let fut_proxy_2 = tokio::io::copy(&mut erh, &mut wh);
|
||||
match join!(fut_proxy_1, fut_proxy_2) {
|
||||
(Ok(_), Ok(_)) => Ok(()),
|
||||
_ => bail!("TLS proxy error"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Result<(), Error> {
|
||||
let runtime = globals.runtime.clone();
|
||||
async fn tcp_acceptor(globals: Arc<Globals>, mut tcp_listener: TcpListener) -> Result<(), Error> {
|
||||
let runtime_handle = globals.runtime_handle.clone();
|
||||
let mut tcp_listener = tcp_listener.incoming();
|
||||
let timeout = globals.tcp_timeout;
|
||||
let concurrent_connections = globals.tcp_concurrent_connections.clone();
|
||||
@ -301,8 +300,8 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
|
||||
Ok(())
|
||||
};
|
||||
let fut_abort = rx;
|
||||
let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout);
|
||||
runtime.spawn(fut_all.map(move |_| {
|
||||
let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort));
|
||||
runtime_handle.spawn(fut_all.map(move |_| {
|
||||
let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
#[cfg(feature = "metrics")]
|
||||
varz.inflight_tcp_queries
|
||||
@ -317,7 +316,7 @@ async fn udp_acceptor(
|
||||
globals: Arc<Globals>,
|
||||
net_udp_socket: std::net::UdpSocket,
|
||||
) -> Result<(), Error> {
|
||||
let runtime = globals.runtime.clone();
|
||||
let runtime_handle = globals.runtime_handle.clone();
|
||||
let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?;
|
||||
let timeout = globals.udp_timeout;
|
||||
let concurrent_connections = globals.udp_concurrent_connections.clone();
|
||||
@ -356,8 +355,8 @@ async fn udp_acceptor(
|
||||
let concurrent_connections = concurrent_connections.clone();
|
||||
let fut = handle_client_query(globals, client_ctx, packet);
|
||||
let fut_abort = rx;
|
||||
let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout);
|
||||
runtime.spawn(fut_all.map(move |_| {
|
||||
let fut_all = tokio::time::timeout(timeout, future::select(fut.boxed(), fut_abort));
|
||||
runtime_handle.spawn(fut_all.map(move |_| {
|
||||
let _count = concurrent_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
#[cfg(feature = "metrics")]
|
||||
varz.inflight_udp_queries
|
||||
@ -368,22 +367,27 @@ async fn udp_acceptor(
|
||||
|
||||
async fn start(
|
||||
globals: Arc<Globals>,
|
||||
runtime: Arc<Runtime>,
|
||||
listeners: Vec<(TcpListener, std::net::UdpSocket)>,
|
||||
runtime_handle: Handle,
|
||||
listeners: Vec<(std::net::TcpListener, std::net::UdpSocket)>,
|
||||
) -> Result<(), Error> {
|
||||
for listener in listeners {
|
||||
runtime.spawn(tcp_acceptor(globals.clone(), listener.0).map(|_| {}));
|
||||
runtime.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {}));
|
||||
let tcp_listener_str = format!("{:?}", listener.0);
|
||||
let tokio_tcp_listener = match TcpListener::from_std(listener.0) {
|
||||
Ok(tcp_listener) => tcp_listener,
|
||||
Err(e) => bail!("{}/TCP: {}", tcp_listener_str, e),
|
||||
};
|
||||
runtime_handle.spawn(tcp_acceptor(globals.clone(), tokio_tcp_listener).map(|_| {}));
|
||||
runtime_handle.spawn(udp_acceptor(globals.clone(), listener.1).map(|_| {}));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bind_listeners(
|
||||
listen_addrs: &[SocketAddr],
|
||||
) -> Result<Vec<(TcpListener, std::net::UdpSocket)>, Error> {
|
||||
) -> Result<Vec<(std::net::TcpListener, std::net::UdpSocket)>, Error> {
|
||||
let mut sockets = Vec::with_capacity(listen_addrs.len());
|
||||
for listen_addr in listen_addrs {
|
||||
let std_socket = match listen_addr {
|
||||
let tcp_listener = match listen_addr {
|
||||
SocketAddr::V4(_) => net2::TcpBuilder::new_v4()?
|
||||
.reuse_address(true)?
|
||||
.bind(&listen_addr)?
|
||||
@ -394,10 +398,6 @@ fn bind_listeners(
|
||||
.bind(&listen_addr)?
|
||||
.listen(1024)?,
|
||||
};
|
||||
let tcp_listener = match TcpListener::from_std(std_socket, &Default::default()) {
|
||||
Ok(tcp_listener) => tcp_listener,
|
||||
Err(e) => bail!("{}/TCP: {}", listen_addr, e),
|
||||
};
|
||||
let std_socket = match listen_addr {
|
||||
SocketAddr::V4(_) => net2::UdpBuilder::new_v4()?
|
||||
.reuse_address(true)?
|
||||
@ -495,8 +495,10 @@ fn main() -> Result<(), Error> {
|
||||
let external_addr = config.external_addr.map(|addr| SocketAddr::new(addr, 0));
|
||||
|
||||
let mut runtime_builder = tokio::runtime::Builder::new();
|
||||
runtime_builder.name_prefix("encrypted-dns-");
|
||||
let runtime = Arc::new(runtime_builder.build()?);
|
||||
runtime_builder.enable_all();
|
||||
runtime_builder.threaded_scheduler();
|
||||
runtime_builder.thread_name("encrypted-dns-");
|
||||
let mut runtime = runtime_builder.build()?;
|
||||
|
||||
let listen_addrs: Vec<_> = config.listen_addrs.iter().map(|x| x.local).collect();
|
||||
let listeners = bind_listeners(&listen_addrs)
|
||||
@ -627,9 +629,9 @@ fn main() -> Result<(), Error> {
|
||||
anonymized_dns.blacklisted_ips,
|
||||
),
|
||||
};
|
||||
|
||||
let runtime_handle = runtime.handle();
|
||||
let globals = Arc::new(Globals {
|
||||
runtime: runtime.clone(),
|
||||
runtime_handle: runtime_handle.clone(),
|
||||
state_file: state_file.to_path_buf(),
|
||||
dnscrypt_encryption_params_set: Arc::new(RwLock::new(Arc::new(
|
||||
dnscrypt_encryption_params_set,
|
||||
@ -671,18 +673,22 @@ fn main() -> Result<(), Error> {
|
||||
#[cfg(feature = "metrics")]
|
||||
{
|
||||
if let Some(metrics_config) = config.metrics {
|
||||
runtime.spawn(
|
||||
metrics::prometheus_service(globals.varz.clone(), metrics_config, runtime.clone())
|
||||
.map_err(|e| {
|
||||
error!("Unable to start the metrics service: [{}]", e);
|
||||
std::process::exit(1);
|
||||
})
|
||||
.map(|_| ()),
|
||||
runtime_handle.spawn(
|
||||
metrics::prometheus_service(
|
||||
globals.varz.clone(),
|
||||
metrics_config,
|
||||
runtime_handle.clone(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!("Unable to start the metrics service: [{}]", e);
|
||||
std::process::exit(1);
|
||||
})
|
||||
.map(|_| ()),
|
||||
);
|
||||
}
|
||||
}
|
||||
runtime.spawn(
|
||||
start(globals, runtime.clone(), listeners)
|
||||
runtime_handle.spawn(
|
||||
start(globals, runtime_handle.clone(), listeners)
|
||||
.map_err(|e| {
|
||||
error!("Unable to start the service: [{}]", e);
|
||||
std::process::exit(1);
|
||||
|
@ -14,8 +14,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::prelude::*;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::runtime::Handle;
|
||||
|
||||
const METRICS_CONNECTION_TIMEOUT_SECS: u64 = 10;
|
||||
const METRICS_MAX_CONCURRENT_CONNECTIONS: u32 = 2;
|
||||
@ -50,7 +49,7 @@ async fn handle_client_connection(
|
||||
pub async fn prometheus_service(
|
||||
varz: Varz,
|
||||
metrics_config: MetricsConfig,
|
||||
runtime: Arc<Runtime>,
|
||||
runtime_handle: Handle,
|
||||
) -> Result<(), Error> {
|
||||
let path = Arc::new(metrics_config.path);
|
||||
let std_socket = match metrics_config.listen_addr {
|
||||
@ -64,7 +63,7 @@ pub async fn prometheus_service(
|
||||
.bind(&metrics_config.listen_addr)?
|
||||
.listen(1024)?,
|
||||
};
|
||||
let mut stream = TcpListener::from_std(std_socket, &Default::default())?;
|
||||
let mut stream = TcpListener::from_std(std_socket)?;
|
||||
let concurrent_connections = Arc::new(AtomicU32::new(0));
|
||||
loop {
|
||||
let (client, _client_addr) = stream.accept().await?;
|
||||
@ -80,14 +79,14 @@ pub async fn prometheus_service(
|
||||
service_fn(move |req| handle_client_connection(req, varz.clone(), path.clone()));
|
||||
let connection = Http::new().serve_connection(client, service);
|
||||
let concurrent_connections = concurrent_connections.clone();
|
||||
runtime.spawn(
|
||||
connection
|
||||
.timeout(std::time::Duration::from_secs(
|
||||
METRICS_CONNECTION_TIMEOUT_SECS,
|
||||
))
|
||||
.map(move |_| {
|
||||
concurrent_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
}),
|
||||
runtime_handle.spawn(
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_secs(METRICS_CONNECTION_TIMEOUT_SECS),
|
||||
connection,
|
||||
)
|
||||
.map(move |_| {
|
||||
concurrent_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
}),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
|
@ -29,7 +29,7 @@ pub async fn resolve_udp(
|
||||
))?,
|
||||
},
|
||||
};
|
||||
let mut ext_socket = UdpSocket::from_std(std_socket, &Default::default())?;
|
||||
let mut ext_socket = UdpSocket::from_std(std_socket)?;
|
||||
ext_socket.connect(&globals.upstream_addr).await?;
|
||||
dns::set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?;
|
||||
let mut response;
|
||||
@ -38,9 +38,7 @@ pub async fn resolve_udp(
|
||||
ext_socket.send(&packet).await?;
|
||||
response = vec![0u8; DNS_MAX_PACKET_SIZE];
|
||||
dns::set_rcode_servfail(&mut response);
|
||||
let fut = ext_socket
|
||||
.recv_from(&mut response[..])
|
||||
.timeout(timeout_if_cached);
|
||||
let fut = tokio::time::timeout(timeout_if_cached, ext_socket.recv_from(&mut response[..]));
|
||||
match fut.await {
|
||||
Ok(Ok((response_len, response_addr))) => {
|
||||
response.truncate(response_len);
|
||||
@ -78,8 +76,7 @@ pub async fn resolve_tcp(
|
||||
SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?.to_tcp_stream()?,
|
||||
},
|
||||
};
|
||||
let mut ext_socket =
|
||||
TcpStream::connect_std(std_socket, &globals.upstream_addr, &Default::default()).await?;
|
||||
let mut ext_socket = TcpStream::connect_std(std_socket, &globals.upstream_addr).await?;
|
||||
ext_socket.set_nodelay(true)?;
|
||||
let mut binlen = [0u8, 0];
|
||||
BigEndian::write_u16(&mut binlen[..], packet.len() as u16);
|
||||
|
Loading…
Reference in New Issue
Block a user