diff --git a/Cargo.toml b/Cargo.toml index af18385..13d376e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ futures-preview = { version = "=0.3.0-alpha.18", features = ["compat", "async-aw jemallocator = "0.3.2" libsodium-sys="0.2.4" log = "0.4.8" -parking_lot = "0.9.0" +rand = "0.7.0" tokio = "=0.2.0-alpha.4" [profile.release] diff --git a/src/dns.rs b/src/dns.rs index 288e30a..e83ee79 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -3,14 +3,15 @@ use crate::errors::*; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -const DNS_MAX_HOSTNAME_LEN: usize = 256; +pub const DNS_MAX_HOSTNAME_LEN: usize = 256; +pub const DNS_HEADER_SIZE: usize = 12; +pub const DNS_OFFSET_FLAGS: usize = 2; +pub const DNS_MAX_PACKET_SIZE: usize = 8192; + const DNS_MAX_INDIRECTIONS: usize = 16; -const DNS_HEADER_SIZE: usize = 12; -const DNS_OFFSET_FLAGS: usize = 2; const DNS_FLAGS_TC: u16 = 2u16 << 8; const DNS_FLAGS_QR: u16 = 128u16 << 8; const DNS_FLAGS_RA: u16 = 128; -const DNS_MAX_PACKET_SIZE: usize = 65_533; const DNS_OFFSET_QUESTION: usize = DNS_HEADER_SIZE; const DNS_TYPE_OPT: u16 = 41; const DNS_TYPE_TXT: u16 = 16; @@ -52,6 +53,16 @@ fn arcount_inc(packet: &mut [u8]) -> Result<(), Error> { Ok(()) } +#[inline] +pub fn tid(packet: &[u8]) -> u16 { + BigEndian::read_u16(&packet[0..]) +} + +#[inline] +pub fn set_tid(packet: &mut [u8], tid: u16) { + BigEndian::write_u16(&mut packet[0..], tid); +} + #[inline] pub fn authoritative_response(packet: &mut [u8]) { let current_flags = BigEndian::read_u16(&packet[DNS_OFFSET_FLAGS..]); @@ -248,12 +259,11 @@ pub fn set_edns_max_payload_size(packet: &mut Vec, max_payload_size: u16) -> let mut offset = skip_name(packet, DNS_OFFSET_QUESTION)?; assert!(offset > DNS_OFFSET_QUESTION); - ensure!(packet_len - offset <= 4, "Short packet"); + ensure!(packet_len - offset >= 4, "Short packet"); offset += 4; let (ancount, nscount, arcount) = (ancount(packet), nscount(packet), arcount(packet)); offset = traverse_rrs(packet, offset, ancount + nscount, |_offset| Ok(()))?; let mut edns_payload_set = false; - traverse_rrs_mut(packet, offset, arcount, |packet, offset| { let qtype = BigEndian::read_u16(&packet[offset..]); if qtype == DNS_TYPE_OPT { diff --git a/src/globals.rs b/src/globals.rs index de06219..cd26a49 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -3,6 +3,7 @@ use crate::dnscrypt_certs::*; use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use tokio::runtime::Runtime; #[derive(Debug)] @@ -12,4 +13,8 @@ pub struct Globals { pub dnscrypt_certs: Vec, pub provider_name: String, pub listen_addr: SocketAddr, + pub external_addr: SocketAddr, + pub upstream_addr: SocketAddr, + pub udp_timeout: Duration, + pub tcp_timeout: Duration, } diff --git a/src/main.rs b/src/main.rs index a4c4fe5..b222f4c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,18 +28,19 @@ use dnscrypt_certs::*; use errors::*; use globals::*; -use byteorder::{BigEndian, ByteOrder}; +use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; use clap::Arg; use dnsstamps::{InformalProperty, WithInformalProperty}; use failure::{bail, ensure}; use futures::prelude::*; use futures::{FutureExt, StreamExt}; -use parking_lot::RwLock; +use rand::prelude::*; use std::convert::TryFrom; use std::mem; use std::net::SocketAddr; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::sync::Arc; +use std::time::Duration; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::prelude::*; use tokio::runtime::{current_thread::Handle, Runtime}; @@ -49,14 +50,12 @@ const DNSCRYPT_QUERY_MAX_SIZE: usize = 512; #[derive(Debug)] struct UdpClientCtx { - udp_socket_fd: RawFd, - packet: Vec, + net_udp_socket: std::net::UdpSocket, client_addr: SocketAddr, } #[derive(Debug)] struct TcpClientCtx { - packet: Vec, client_connection: TcpStream, } @@ -66,41 +65,65 @@ enum ClientCtx { Tcp(TcpClientCtx), } -async fn respond_to_query(client_ctx: ClientCtx) -> Result<(), Error> { +async fn respond_to_query(client_ctx: ClientCtx, packet: Vec) -> Result<(), Error> { match client_ctx { ClientCtx::Udp(client_ctx) => { - let packet = client_ctx.packet; - let udp_socket = unsafe { std::net::UdpSocket::from_raw_fd(client_ctx.udp_socket_fd) }; - let _ = udp_socket.send_to(&packet, client_ctx.client_addr); - mem::forget(udp_socket); + let net_udp_socket = client_ctx.net_udp_socket; + net_udp_socket.send_to(&packet, client_ctx.client_addr)?; + } + ClientCtx::Tcp(client_ctx) => { + let packet_len = packet.len(); + ensure!(packet_len <= DNS_MAX_PACKET_SIZE, "Packet too large"); + let mut client_connection = client_ctx.client_connection; + let mut binlen = [0u8, 0]; + BigEndian::write_u16(&mut binlen[..], packet_len as u16); + client_connection.write_all(&binlen).await?; + client_connection.write_all(&packet).await?; } - ClientCtx::Tcp(client_ctx) => {} } Ok(()) } -async fn handle_client_query(client_ctx: ClientCtx) -> Result<(), Error> { - // if let Some(synth_packet) = - // serve_certificates(&packet, &globals.provider_name, &globals.dnscrypt_certs)? - // { - // let _ = udp_socket.send_to(&synth_packet, client_addr).await; - // continue; - // } - // truncate(&mut packet); - // let _ = udp_socket.send_to(&packet, client_addr).await; - - dbg!(&client_ctx); - respond_to_query(client_ctx).await +async fn handle_client_query( + globals: Arc, + client_ctx: ClientCtx, + mut packet: Vec, +) -> Result<(), Error> { + if let Some(synth_packet) = + serve_certificates(&packet, &globals.provider_name, &globals.dnscrypt_certs)? + { + return respond_to_query(client_ctx, synth_packet).await; + } + set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?; + let original_tid = dns::tid(&packet); + let tid = random(); + dns::set_tid(&mut packet, tid); + let mut ext_socket = UdpSocket::bind(&globals.external_addr).await?; + ext_socket.connect(&globals.upstream_addr).await?; + ext_socket.send(&packet).await.unwrap(); + let mut response = vec![0u8; DNS_MAX_PACKET_SIZE]; + let response_len = ext_socket.recv(&mut response[..]).await?; + ensure!(response_len > DNS_HEADER_SIZE, "Short packet"); + response.truncate(response_len); + ensure!(dns::tid(&response) == tid, "Unexpected transaction ID"); + ensure!( + dns::qname(&packet)? == dns::qname(&response)?, + "Unexpected query name in the response" + ); + dns::set_tid(&mut response, original_tid); + respond_to_query(client_ctx, response).await } async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Result<(), Error> { let runtime = globals.runtime.clone(); let mut tcp_listener = tcp_listener.incoming(); + let timeout = globals.tcp_timeout; while let Some(client) = tcp_listener.next().await { let mut client_connection: TcpStream = match client { Ok(client_connection) => client_connection, Err(e) => bail!(e), }; + let globals = globals.clone(); runtime.spawn( async { let mut binlen = [0u8, 0]; @@ -112,39 +135,46 @@ async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Resul ); let mut packet = vec![0u8; packet_len]; client_connection.read_exact(&mut packet).await?; - let client_ctx = ClientCtx::Tcp(TcpClientCtx { - packet, - client_connection, - }); - let _ = handle_client_query(client_ctx).await; + let client_ctx = ClientCtx::Tcp(TcpClientCtx { client_connection }); + let _ = handle_client_query(globals, client_ctx, packet).await; Ok(()) } + .timeout(timeout) .map(|_| ()), ); } Ok(()) } -async fn udp_acceptor(globals: Arc, mut udp_socket: UdpSocket) -> Result<(), Error> { +async fn udp_acceptor( + globals: Arc, + net_udp_socket: std::net::UdpSocket, +) -> Result<(), Error> { let runtime = globals.runtime.clone(); + let mut tokio_udp_socket = UdpSocket::try_from(net_udp_socket.try_clone()?)?; + let timeout = globals.udp_timeout; loop { let mut packet = vec![0u8; DNSCRYPT_QUERY_MAX_SIZE]; - let (packet_len, client_addr) = udp_socket.recv_from(&mut packet).await?; - let udp_socket_fd = udp_socket.as_raw_fd(); + let (packet_len, client_addr) = tokio_udp_socket.recv_from(&mut packet).await?; + let net_udp_socket = net_udp_socket.try_clone()?; packet.truncate(packet_len); let client_ctx = ClientCtx::Udp(UdpClientCtx { - udp_socket_fd, - packet, + net_udp_socket, client_addr, }); - runtime.spawn(async { handle_client_query(client_ctx).await }.map(|_| ())); + let globals = globals.clone(); + runtime.spawn( + async { handle_client_query(globals, client_ctx, packet).await } + .timeout(timeout) + .map(|_| ()), + ); } } async fn start(globals: Arc, runtime: Arc) -> Result<(), Error> { let socket_addr: SocketAddr = globals.listen_addr; let tcp_listener = TcpListener::bind(&socket_addr).await?; - let udp_socket = UdpSocket::bind(&socket_addr).await?; + let udp_socket = std::net::UdpSocket::bind(&socket_addr)?; runtime.spawn(tcp_acceptor(globals.clone(), tcp_listener).map(|_| {})); runtime.spawn(udp_acceptor(globals.clone(), udp_socket).map(|_| {})); Ok(()) @@ -171,17 +201,46 @@ fn main() -> Result<(), Error> { .required(true) .help("Provider name"), ) + .arg( + Arg::with_name("upstream-addr") + .value_name("upstream-addr") + .takes_value(true) + .default_value("9.9.9.9:53") + .required(true) + .help("Address and port of the upstream server"), + ) + .arg( + Arg::with_name("external-addr") + .value_name("external-addr") + .takes_value(true) + .default_value("0.0.0.0:0") + .required(true) + .help("Address and port to connect from"), + ) .get_matches(); + let listen_addr = matches .value_of("listen-addr") .unwrap() .to_ascii_lowercase(); + let provider_name = match matches.value_of("provider-name").unwrap() { provider_name if provider_name.starts_with("2.dnscrypt.") => provider_name.to_string(), provider_name => format!("2.dnscrypt.{}", provider_name), }; + let listen_addr_s = matches.value_of("listen-addr").unwrap(); let listen_addr: SocketAddr = listen_addr_s.parse()?; + + let upstream_addr_s = matches.value_of("upstream-addr").unwrap(); + let upstream_addr: SocketAddr = upstream_addr_s.parse()?; + + let external_addr_s = matches.value_of("external-addr").unwrap(); + let external_addr: SocketAddr = external_addr_s.parse()?; + + let udp_timeout = Duration::from_secs(10); + let tcp_timeout = Duration::from_secs(10); + let resolver_kp = SignKeyPair::new(); info!("Server address: {}", listen_addr); @@ -209,6 +268,10 @@ fn main() -> Result<(), Error> { dnscrypt_certs: vec![dnscrypt_cert], provider_name, listen_addr, + upstream_addr, + external_addr, + tcp_timeout, + udp_timeout, }); runtime.spawn(start(globals, runtime.clone()).map(|_| ())); runtime.block_on(future::pending::<()>());