From e1bd1f52dcfc6bb2d95b5dbad435336feef5f30c Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Mon, 9 Sep 2019 14:01:10 +0200 Subject: [PATCH] Recycle old connections --- Cargo.toml | 1 + src/globals.rs | 7 +++++++ src/main.rs | 43 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 13d376e..debc482 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ futures-preview = { version = "=0.3.0-alpha.18", features = ["compat", "async-aw jemallocator = "0.3.2" libsodium-sys="0.2.4" log = "0.4.8" +parking_lot = "0.9.0" rand = "0.7.0" tokio = "=0.2.0-alpha.4" diff --git a/src/globals.rs b/src/globals.rs index f62e76e..c29d9bf 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -1,11 +1,14 @@ use crate::crypto::*; use crate::dnscrypt_certs::*; +use parking_lot::Mutex; +use std::collections::vec_deque::VecDeque; use std::net::SocketAddr; use std::sync::atomic::AtomicU32; use std::sync::Arc; use std::time::Duration; use tokio::runtime::Runtime; +use tokio::sync::oneshot; #[derive(Debug)] pub struct Globals { @@ -20,4 +23,8 @@ pub struct Globals { pub tcp_timeout: Duration, pub udp_concurrent_connections: Arc, pub tcp_concurrent_connections: Arc, + pub udp_max_active_connections: u32, + pub tcp_max_active_connections: u32, + pub udp_active_connections: Arc>>>, + pub tcp_active_connections: Arc>>>, } diff --git a/src/main.rs b/src/main.rs index b167d70..47ebb50 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,8 +33,10 @@ use clap::Arg; use dnsstamps::{InformalProperty, WithInformalProperty}; use failure::{bail, ensure}; use futures::prelude::*; -use futures::{FutureExt, StreamExt}; +use futures::{pin_mut, FutureExt, StreamExt}; +use parking_lot::Mutex; use rand::prelude::*; +use std::collections::vec_deque::VecDeque; use std::convert::TryFrom; use std::mem; use std::net::SocketAddr; @@ -45,6 +47,7 @@ use std::time::Duration; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::prelude::*; use tokio::runtime::{current_thread::Handle, Runtime}; +use tokio::sync::oneshot; const DNSCRYPT_QUERY_MIN_SIZE: usize = 12; const DNSCRYPT_QUERY_MAX_SIZE: usize = 512; @@ -155,11 +158,21 @@ async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Resul let mut tcp_listener = tcp_listener.incoming(); let timeout = globals.tcp_timeout; let concurrent_connections = globals.tcp_concurrent_connections.clone(); + let active_connections = globals.tcp_active_connections.clone(); while let Some(client) = tcp_listener.next().await { let mut client_connection: TcpStream = match client { Ok(client_connection) => client_connection, Err(e) => bail!(e), }; + let (tx, rx) = oneshot::channel::<()>(); + { + let mut active_connections = active_connections.lock(); + if active_connections.len() >= globals.tcp_max_active_connections as _ { + let tx_oldest = active_connections.pop_back().unwrap(); + let _ = tx_oldest.send(()); + } + active_connections.push_front(tx); + } concurrent_connections.fetch_add(1, Ordering::Relaxed); client_connection.set_nodelay(true)?; let globals = globals.clone(); @@ -178,7 +191,9 @@ async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Resul let _ = handle_client_query(globals, client_ctx, packet).await; Ok(()) }; - runtime.spawn(fut.timeout(timeout).map(move |_| { + let fut_abort = rx; + let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout); + runtime.spawn(fut_all.map(move |_| { concurrent_connections.fetch_sub(1, Ordering::Relaxed); })); } @@ -193,6 +208,7 @@ async fn udp_acceptor( let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?; let timeout = globals.udp_timeout; let concurrent_connections = globals.udp_concurrent_connections.clone(); + let active_connections = globals.udp_active_connections.clone(); loop { let mut packet = vec![0u8; DNSCRYPT_QUERY_MAX_SIZE]; let (packet_len, client_addr) = tokio_udp_socket.recv_from(&mut packet).await?; @@ -202,11 +218,22 @@ async fn udp_acceptor( net_udp_socket, client_addr, }); + let (tx, rx) = oneshot::channel::<()>(); + { + let mut active_connections = active_connections.lock(); + if active_connections.len() >= globals.tcp_max_active_connections as _ { + let tx_oldest = active_connections.pop_back().unwrap(); + let _ = tx_oldest.send(()); + } + active_connections.push_front(tx); + } concurrent_connections.fetch_add(1, Ordering::Relaxed); let globals = globals.clone(); let concurrent_connections = concurrent_connections.clone(); let fut = handle_client_query(globals, client_ctx, packet); - runtime.spawn(fut.timeout(timeout).map(move |_| { + let fut_abort = rx; + let fut_all = future::select(fut.boxed(), fut_abort).timeout(timeout); + runtime.spawn(fut_all.map(move |_| { concurrent_connections.fetch_sub(1, Ordering::Relaxed); })); } @@ -303,6 +330,8 @@ fn main() -> Result<(), Error> { let dnscrypt_cert = DNSCryptCert::new(&resolver_kp); let runtime = Arc::new(Runtime::new()?); + let udp_max_active_connections = 1000; + let tcp_max_active_connections = 100; let globals = Arc::new(Globals { runtime: runtime.clone(), resolver_kp, @@ -315,6 +344,14 @@ fn main() -> Result<(), Error> { udp_timeout, udp_concurrent_connections: Arc::new(AtomicU32::new(0)), tcp_concurrent_connections: Arc::new(AtomicU32::new(0)), + udp_max_active_connections, + tcp_max_active_connections, + udp_active_connections: Arc::new(Mutex::new(VecDeque::with_capacity( + udp_max_active_connections as _, + ))), + tcp_active_connections: Arc::new(Mutex::new(VecDeque::with_capacity( + tcp_max_active_connections as _, + ))), }); runtime.spawn(start(globals, runtime.clone()).map(|_| ())); runtime.block_on(future::pending::<()>());