diff --git a/src/dnscrypt.rs b/src/dnscrypt.rs index 3821091..097d09c 100644 --- a/src/dnscrypt.rs +++ b/src/dnscrypt.rs @@ -4,17 +4,26 @@ use crate::dnscrypt_certs::*; use crate::errors::*; use libsodium_sys::*; +use rand::prelude::*; use std::ffi::CStr; use std::ptr; pub const DNSCRYPT_CLIENT_MAGIC_SIZE: usize = 8; pub const DNSCRYPT_CLIENT_PK_SIZE: usize = 32; -pub const DNSCRYPT_CLIENT_NONCE_SIZE: usize = 12; +pub const DNSCRYPT_CLIENT_NONCE_SIZE: usize = + crypto_box_curve25519xchacha20poly1305_HALFNONCEBYTES as usize; pub fn decrypt( wrapped_packet: &[u8], dnscrypt_encryption_params_set: &[DNSCryptEncryptionParams], -) -> Result, Error> { +) -> Result< + ( + SharedKey, + [u8; crypto_box_curve25519xchacha20poly1305_NONCEBYTES as usize], + Vec, + ), + Error, +> { ensure!( wrapped_packet.len() >= DNSCRYPT_CLIENT_MAGIC_SIZE @@ -37,11 +46,20 @@ pub fn decrypt( .find(|p| p.client_magic() == client_magic) .ok_or_else(|| format_err!("Client magic not found"))?; - let mut nonce = vec![0u8; crypto_box_curve25519xchacha20poly1305_NONCEBYTES as usize]; - &mut nonce[..crypto_box_curve25519xchacha20poly1305_HALFNONCEBYTES] - .copy_from_slice(client_nonce); + let mut nonce = [0u8; crypto_box_curve25519xchacha20poly1305_NONCEBYTES as usize]; + &mut nonce[..DNSCRYPT_CLIENT_NONCE_SIZE].copy_from_slice(client_nonce); let resolver_kp = dnscrypt_encryption_params.resolver_kp(); - let shared_secret = resolver_kp.compute_shared_key(client_pk)?; - let packet = shared_secret.decrypt(&nonce, encrypted_packet)?; - Ok(packet) + let shared_key = resolver_kp.compute_shared_key(client_pk)?; + let packet = shared_key.decrypt(&nonce, encrypted_packet)?; + rand::thread_rng().fill_bytes(&mut nonce[DNSCRYPT_CLIENT_NONCE_SIZE..]); + + Ok((shared_key, nonce, packet)) +} + +pub fn encrypt( + packet: &[u8], + shared_key: &SharedKey, + nonce: &[u8; crypto_box_curve25519xchacha20poly1305_NONCEBYTES as usize], +) { + // } diff --git a/src/main.rs b/src/main.rs index 10b1c39..386b599 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,7 +72,11 @@ enum ClientCtx { Tcp(TcpClientCtx), } -async fn respond_to_query(client_ctx: ClientCtx, packet: Vec) -> Result<(), Error> { +async fn respond_to_query( + client_ctx: ClientCtx, + packet: Vec, + shared_key: Option, +) -> Result<(), Error> { ensure!(dns::is_response(&packet), "Packet is not a response"); match client_ctx { ClientCtx::Udp(client_ctx) => { @@ -98,21 +102,21 @@ async fn handle_client_query( client_ctx: ClientCtx, encrypted_packet: Vec, ) -> Result<(), Error> { - let packet = dnscrypt::decrypt(&encrypted_packet, &globals.dnscrypt_encryption_params_set); - let mut packet = match packet { - Ok(packet) => packet, - Err(_) => { - let packet = encrypted_packet; - if let Some(synth_packet) = serve_certificates( - &packet, - &globals.provider_name, - &globals.dnscrypt_encryption_params_set, - )? { - return respond_to_query(client_ctx, synth_packet).await; + let (shared_key, nonce, mut packet) = + match dnscrypt::decrypt(&encrypted_packet, &globals.dnscrypt_encryption_params_set) { + Ok(x) => x, + Err(_) => { + let packet = encrypted_packet; + if let Some(synth_packet) = serve_certificates( + &packet, + &globals.provider_name, + &globals.dnscrypt_encryption_params_set, + )? { + return respond_to_query(client_ctx, synth_packet, None).await; + } + bail!("Unencrypted query"); } - bail!("Unencrypted query"); - } - }; + }; ensure!(packet.len() >= DNS_HEADER_SIZE, "Short packet"); ensure!(qdcount(&packet) == 1, "No question"); ensure!( @@ -171,7 +175,7 @@ async fn handle_client_query( ); } dns::set_tid(&mut response, original_tid); - respond_to_query(client_ctx, response).await + respond_to_query(client_ctx, response, Some(shared_key)).await } async fn tcp_acceptor(globals: Arc, tcp_listener: TcpListener) -> Result<(), Error> {