diff --git a/src/crypto.rs b/src/crypto.rs index f7ee616..ef8d921 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -91,7 +91,7 @@ impl SignKeyPair { } } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct CryptSK([u8; crypto_box_curve25519xchacha20poly1305_SECRETKEYBYTES as usize]); impl CryptSK { @@ -108,7 +108,7 @@ impl CryptSK { } } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct CryptPK([u8; crypto_box_curve25519xchacha20poly1305_PUBLICKEYBYTES as usize]); impl CryptPK { @@ -125,7 +125,7 @@ impl CryptPK { } } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct CryptKeyPair { pub sk: CryptSK, pub pk: CryptPK, diff --git a/src/dns.rs b/src/dns.rs index 206b328..a428a18 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -2,6 +2,7 @@ use crate::dnscrypt_certs::*; use crate::errors::*; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; +use std::sync::Arc; pub const DNS_MAX_HOSTNAME_SIZE: usize = 256; pub const DNS_HEADER_SIZE: usize = 12; @@ -293,7 +294,7 @@ pub fn set_edns_max_payload_size(packet: &mut Vec, max_payload_size: u16) -> pub fn serve_certificates<'t>( client_packet: &[u8], expected_qname: &str, - dnscrypt_encryption_params_set: impl IntoIterator, + dnscrypt_encryption_params_set: impl IntoIterator>, ) -> Result>, Error> { ensure!(client_packet.len() >= DNS_HEADER_SIZE, "Short packet"); ensure!(qdcount(&client_packet) == 1, "No question"); @@ -322,7 +323,7 @@ pub fn serve_certificates<'t>( packet.write_u16::(0xc000 + DNS_HEADER_SIZE as u16)?; packet.write_u16::(DNS_TYPE_TXT)?; packet.write_u16::(DNS_CLASS_INET)?; - packet.write_u32::(28800)?; + packet.write_u32::(DNSCRYPT_CERTS_RENEWAL)?; packet.write_u16::(1 + cert_bin.len() as u16)?; packet.write_u8(cert_bin.len() as u8)?; packet.extend_from_slice(&cert_bin[..]); diff --git a/src/dnscrypt.rs b/src/dnscrypt.rs index 87cc782..4a1c9de 100644 --- a/src/dnscrypt.rs +++ b/src/dnscrypt.rs @@ -5,6 +5,7 @@ use crate::errors::*; use libsodium_sys::*; use rand::prelude::*; +use std::sync::Arc; pub const DNSCRYPT_FULL_NONCE_SIZE: usize = crypto_box_curve25519xchacha20poly1305_NONCEBYTES as usize; @@ -40,7 +41,7 @@ pub const DNSCRYPT_TCP_RESPONSE_MAX_SIZE: usize = pub fn decrypt( wrapped_packet: &[u8], - dnscrypt_encryption_params_set: &[DNSCryptEncryptionParams], + dnscrypt_encryption_params_set: &[Arc], ) -> Result<(SharedKey, [u8; DNSCRYPT_FULL_NONCE_SIZE as usize], Vec), Error> { ensure!( wrapped_packet.len() diff --git a/src/dnscrypt_certs.rs b/src/dnscrypt_certs.rs index 24dde9d..c924b71 100644 --- a/src/dnscrypt_certs.rs +++ b/src/dnscrypt_certs.rs @@ -1,11 +1,17 @@ use crate::crypto::*; +use crate::globals::*; use byteorder::{BigEndian, ByteOrder}; use coarsetime::{Clock, Duration}; +use parking_lot::RwLock; use std::mem; use std::slice; +use std::sync::Arc; use std::time::SystemTime; +pub const DNSCRYPT_CERTS_TTL: u32 = 86400; +pub const DNSCRYPT_CERTS_RENEWAL: u32 = 28800; + fn now() -> u32 { SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) @@ -13,7 +19,7 @@ fn now() -> u32 { .as_secs() as u32 } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] #[repr(C, packed)] pub struct DNSCryptCertInner { resolver_pk: [u8; 32], @@ -30,7 +36,7 @@ impl DNSCryptCertInner { } #[derive(Derivative)] -#[derivative(Debug, Default)] +#[derivative(Debug, Default, Clone)] #[repr(C, packed)] pub struct DNSCryptCert { cert_magic: [u8; 4], @@ -44,7 +50,7 @@ pub struct DNSCryptCert { impl DNSCryptCert { pub fn new(provider_kp: &SignKeyPair, resolver_kp: &CryptKeyPair) -> Self { let ts_start = now(); - let ts_end = ts_start + 86400; + let ts_end = ts_start + DNSCRYPT_CERTS_TTL; let mut dnscrypt_cert = DNSCryptCert::default(); @@ -113,3 +119,32 @@ impl DNSCryptEncryptionParams { &self.resolver_kp } } + +pub struct DNSCryptEncryptionParamsUpdater { + globals: Arc, +} + +impl DNSCryptEncryptionParamsUpdater { + pub fn new(globals: Arc) -> Self { + DNSCryptEncryptionParamsUpdater { globals } + } + + pub async fn run(self) { + let mut fut_interval = tokio::timer::Interval::new_interval( + std::time::Duration::from_secs(u64::from(DNSCRYPT_CERTS_RENEWAL)), + ); + let fut = async { + loop { + fut_interval.next().await; + let new_params = DNSCryptEncryptionParams::new(&self.globals.provider_kp); + debug!("New cert issued"); + let mut params_set = self.globals.dnscrypt_encryption_params_set.write(); + if params_set.len() >= (DNSCRYPT_CERTS_TTL / DNSCRYPT_CERTS_RENEWAL) as usize { + params_set.swap_remove(0); + } + params_set.push(Arc::new(new_params)); + } + }; + fut.await + } +} diff --git a/src/globals.rs b/src/globals.rs index 0c75ba1..78ccae5 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -1,6 +1,7 @@ +use crate::crypto::*; use crate::dnscrypt_certs::*; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use std::collections::vec_deque::VecDeque; use std::net::SocketAddr; use std::sync::atomic::AtomicU32; @@ -12,8 +13,9 @@ use tokio::sync::oneshot; #[derive(Debug)] pub struct Globals { pub runtime: Arc, - pub dnscrypt_encryption_params_set: Vec, + pub dnscrypt_encryption_params_set: Arc>>>, pub provider_name: String, + pub provider_kp: SignKeyPair, pub listen_addrs: Vec, pub external_addr: SocketAddr, pub upstream_addr: SocketAddr, diff --git a/src/main.rs b/src/main.rs index ae33494..309e274 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,6 +42,7 @@ use failure::{bail, ensure}; use futures::join; use futures::prelude::*; use parking_lot::Mutex; +use parking_lot::RwLock; use privdrop::PrivDrop; use rand::prelude::*; use std::collections::vec_deque::VecDeque; @@ -148,15 +149,19 @@ async fn handle_client_query( encrypted_packet: Vec, ) -> Result<(), Error> { let original_packet_size = encrypted_packet.len(); + let mut dnscrypt_encryption_params_set = vec![]; + for params in &*globals.dnscrypt_encryption_params_set.read() { + dnscrypt_encryption_params_set.push((*params).clone()) + } let (shared_key, nonce, mut packet) = - match dnscrypt::decrypt(&encrypted_packet, &globals.dnscrypt_encryption_params_set) { + match dnscrypt::decrypt(&encrypted_packet, &dnscrypt_encryption_params_set) { Ok(x) => x, Err(_) => { let packet = encrypted_packet; if let Some(synth_packet) = serve_certificates( &packet, &globals.provider_name, - &globals.dnscrypt_encryption_params_set, + &dnscrypt_encryption_params_set, )? { return respond_to_query( client_ctx, @@ -463,8 +468,11 @@ fn main() -> Result<(), Error> { } let globals = Arc::new(Globals { runtime: runtime.clone(), - dnscrypt_encryption_params_set: vec![dnscrypt_encryption_params], + dnscrypt_encryption_params_set: Arc::new(RwLock::new(vec![Arc::new( + dnscrypt_encryption_params, + )])), provider_name, + provider_kp, listen_addrs: config.listen_addrs, upstream_addr: config.upstream_addr, tls_upstream_addr: config.tls.upstream_addr, @@ -482,6 +490,8 @@ fn main() -> Result<(), Error> { config.tcp_max_active_connections as _, ))), }); + let updater = DNSCryptEncryptionParamsUpdater::new(globals.clone()); + runtime.spawn(updater.run()); runtime.spawn(start(globals, runtime.clone()).map(|_| ())); runtime.block_on(future::pending::<()>());