From 5afc1f1a6abd894793ec1fdf07d34f88715f67c9 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Sun, 6 Oct 2019 21:04:40 +0200 Subject: [PATCH] Ignore casing for caching --- Cargo.toml | 4 +-- src/dns.rs | 94 +++++++++++++++++++++++++++++++++++++++++-------- src/resolver.rs | 12 ++++--- 3 files changed, 90 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 79b5ba9..6af7fb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ privdrop = "0.3.3" rand = "0.7.2" serde = "1.0.101" serde_derive = "1.0.101" -serde-big-array = "0.1.5" +serde-big-array = "0.2.0" siphasher = "0.3.1" tokio = "=0.2.0-alpha.6" tokio-net = "=0.2.0-alpha.6" @@ -39,7 +39,7 @@ toml = "0.5.3" [dependencies.hyper] optional = true -version = "0.13.0-alpha.3" +version = "0.13.0-alpha.4" default_features = false [dependencies.prometheus] diff --git a/src/dns.rs b/src/dns.rs index 515d446..f568c08 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -136,27 +136,16 @@ pub fn is_truncated(packet: &[u8]) -> bool { } pub fn qname(packet: &[u8]) -> Result, Error> { - assert!(std::usize::MAX > 0xffff); + debug_assert!(std::usize::MAX > 0xffff); + debug_assert!(DNS_MAX_HOSTNAME_SIZE > 0xff); ensure!(qdcount(packet) == 1, "Unexpected query count"); let packet_len = packet.len(); let mut offset = DNS_HEADER_SIZE; let mut qname = Vec::with_capacity(DNS_MAX_HOSTNAME_SIZE); - let mut indirections = 0; loop { ensure!(offset < packet_len, "Short packet"); match packet[offset] as usize { - label_len if label_len & 0xc0 == 0xc0 => { - ensure!(packet_len - offset > 1, "Short packet"); - let new_offset = (BigEndian::read_u16(&packet[offset..]) & 0x3fff) as usize; - indirections += 1; - ensure!( - new_offset >= DNS_HEADER_SIZE - && new_offset != offset - && indirections < DNS_MAX_INDIRECTIONS, - "Too many indirections" - ); - offset = new_offset; - } + label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"), 0 => { if qname.is_empty() { qname.push(b'.') @@ -182,6 +171,83 @@ pub fn qname(packet: &[u8]) -> Result, Error> { Ok(qname) } +pub fn normalize_qname(packet: &mut [u8]) -> Result<(), Error> { + debug_assert!(std::usize::MAX > 0xffff); + debug_assert!(DNS_MAX_HOSTNAME_SIZE > 0xff); + ensure!(qdcount(packet) == 1, "Unexpected query count"); + let packet_len = packet.len(); + let mut offset = DNS_HEADER_SIZE; + loop { + ensure!(offset < packet_len, "Short packet"); + match packet[offset] as usize { + label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"), + 0 => { + break; + } + label_len => { + ensure!(packet_len - offset > 1, "Short packet"); + offset += 1; + ensure!(packet_len - offset > label_len, "Short packet"); + ensure!( + offset - DNS_HEADER_SIZE < DNS_MAX_HOSTNAME_SIZE - label_len, + "Name too long" + ); + packet[offset..offset + label_len] + .iter_mut() + .for_each(|x| *x = x.to_ascii_lowercase()); + offset += label_len; + } + } + } + Ok(()) +} + +pub fn recase_qname(packet: &mut [u8], qname: &[u8]) -> Result<(), Error> { + debug_assert!(std::usize::MAX > 0xffff); + ensure!(qdcount(packet) == 1, "Unexpected query count"); + let packet_len = packet.len(); + let qname_len = qname.len(); + let mut offset = DNS_HEADER_SIZE; + let mut qname_offset = 0; + loop { + ensure!(offset < packet_len, "Short packet"); + match packet[offset] as usize { + label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"), + 0 => { + ensure!( + (qname_len == 1 && qname[0] == b'.') || qname_offset == qname_len, + "Unterminated reference qname" + ); + break; + } + label_len => { + ensure!(packet_len - offset > 1, "Short packet"); + ensure!(qname_offset < qname_len, "Short reference qname"); + offset += 1; + if qname_offset != 0 { + ensure!(qname[qname_offset] == b'.', "Non-matching reference qname"); + qname_offset += 1; + } + ensure!(packet_len - offset > label_len, "Short packet"); + ensure!( + qname_len - qname_offset >= label_len, + "Short reference qname" + ); + packet[offset..offset + label_len] + .iter_mut() + .zip(&qname[qname_offset..qname_offset + label_len]) + .for_each(|(a, b)| { + debug_assert!(a.eq_ignore_ascii_case(b)); + *a = *b + }); + offset += label_len; + qname_offset += label_len; + } + } + } + Ok(()) +} + fn skip_name(packet: &[u8], offset: usize) -> Result { let packet_len = packet.len(); ensure!(offset < packet_len - 1, "Short packet"); diff --git a/src/resolver.rs b/src/resolver.rs index 38fad10..598f49e 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -37,7 +37,7 @@ pub async fn resolve_udp( if response_addr == globals.upstream_addr && response_len >= DNS_HEADER_SIZE && dns::tid(&response) == tid - && packet_qname == dns::qname(&response)?.as_slice() + && packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice()) { break; } @@ -84,7 +84,7 @@ pub async fn resolve_tcp( ext_socket.read_exact(&mut response).await?; ensure!(dns::tid(&response) == tid, "Unexpected transaction ID"); ensure!( - packet_qname == dns::qname(&response)?.as_slice(), + packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice()), "Unexpected query name in the response" ); Ok(response) @@ -135,6 +135,7 @@ pub async fn resolve( globals.cache.lock().insert(packet_hash, cached_response); } dns::set_tid(&mut response, original_tid); + dns::recase_qname(&mut response, &packet_qname)?; #[cfg(feature = "metrics")] globals .varz @@ -157,6 +158,7 @@ pub async fn get_cached_response_or_resolve( } let original_tid = dns::tid(&packet); dns::set_tid(&mut packet, 0); + dns::normalize_qname(&mut packet)?; let mut hasher = globals.hasher; hasher.write(&packet); let packet_hash = hasher.finish128().as_u128(); @@ -172,12 +174,14 @@ pub async fn get_cached_response_or_resolve( let cached_response = match cached_response { None => None, Some(mut cached_response) => { - cached_response.set_tid(original_tid); if !cached_response.has_expired() { trace!("Cached"); #[cfg(feature = "metrics")] globals.varz.client_queries_cached.inc(); - return Ok(cached_response.into_response()); + cached_response.set_tid(original_tid); + let mut response = cached_response.into_response(); + dns::recase_qname(&mut response, &packet_qname)?; + return Ok(response); } trace!("Expired"); #[cfg(feature = "metrics")]