Ignore casing for caching

pull/12/head
Frank Denis 5 years ago
parent 8f03adbdad
commit 5afc1f1a6a

@ -31,7 +31,7 @@ privdrop = "0.3.3"
rand = "0.7.2"
serde = "1.0.101"
serde_derive = "1.0.101"
serde-big-array = "0.1.5"
serde-big-array = "0.2.0"
siphasher = "0.3.1"
tokio = "=0.2.0-alpha.6"
tokio-net = "=0.2.0-alpha.6"
@ -39,7 +39,7 @@ toml = "0.5.3"
[dependencies.hyper]
optional = true
version = "0.13.0-alpha.3"
version = "0.13.0-alpha.4"
default_features = false
[dependencies.prometheus]

@ -136,27 +136,16 @@ pub fn is_truncated(packet: &[u8]) -> bool {
}
pub fn qname(packet: &[u8]) -> Result<Vec<u8>, Error> {
assert!(std::usize::MAX > 0xffff);
debug_assert!(std::usize::MAX > 0xffff);
debug_assert!(DNS_MAX_HOSTNAME_SIZE > 0xff);
ensure!(qdcount(packet) == 1, "Unexpected query count");
let packet_len = packet.len();
let mut offset = DNS_HEADER_SIZE;
let mut qname = Vec::with_capacity(DNS_MAX_HOSTNAME_SIZE);
let mut indirections = 0;
loop {
ensure!(offset < packet_len, "Short packet");
match packet[offset] as usize {
label_len if label_len & 0xc0 == 0xc0 => {
ensure!(packet_len - offset > 1, "Short packet");
let new_offset = (BigEndian::read_u16(&packet[offset..]) & 0x3fff) as usize;
indirections += 1;
ensure!(
new_offset >= DNS_HEADER_SIZE
&& new_offset != offset
&& indirections < DNS_MAX_INDIRECTIONS,
"Too many indirections"
);
offset = new_offset;
}
label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"),
0 => {
if qname.is_empty() {
qname.push(b'.')
@ -182,6 +171,83 @@ pub fn qname(packet: &[u8]) -> Result<Vec<u8>, Error> {
Ok(qname)
}
pub fn normalize_qname(packet: &mut [u8]) -> Result<(), Error> {
debug_assert!(std::usize::MAX > 0xffff);
debug_assert!(DNS_MAX_HOSTNAME_SIZE > 0xff);
ensure!(qdcount(packet) == 1, "Unexpected query count");
let packet_len = packet.len();
let mut offset = DNS_HEADER_SIZE;
loop {
ensure!(offset < packet_len, "Short packet");
match packet[offset] as usize {
label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"),
0 => {
break;
}
label_len => {
ensure!(packet_len - offset > 1, "Short packet");
offset += 1;
ensure!(packet_len - offset > label_len, "Short packet");
ensure!(
offset - DNS_HEADER_SIZE < DNS_MAX_HOSTNAME_SIZE - label_len,
"Name too long"
);
packet[offset..offset + label_len]
.iter_mut()
.for_each(|x| *x = x.to_ascii_lowercase());
offset += label_len;
}
}
}
Ok(())
}
pub fn recase_qname(packet: &mut [u8], qname: &[u8]) -> Result<(), Error> {
debug_assert!(std::usize::MAX > 0xffff);
ensure!(qdcount(packet) == 1, "Unexpected query count");
let packet_len = packet.len();
let qname_len = qname.len();
let mut offset = DNS_HEADER_SIZE;
let mut qname_offset = 0;
loop {
ensure!(offset < packet_len, "Short packet");
match packet[offset] as usize {
label_len if label_len & 0xc0 == 0xc0 => bail!("Indirections"),
0 => {
ensure!(
(qname_len == 1 && qname[0] == b'.') || qname_offset == qname_len,
"Unterminated reference qname"
);
break;
}
label_len => {
ensure!(packet_len - offset > 1, "Short packet");
ensure!(qname_offset < qname_len, "Short reference qname");
offset += 1;
if qname_offset != 0 {
ensure!(qname[qname_offset] == b'.', "Non-matching reference qname");
qname_offset += 1;
}
ensure!(packet_len - offset > label_len, "Short packet");
ensure!(
qname_len - qname_offset >= label_len,
"Short reference qname"
);
packet[offset..offset + label_len]
.iter_mut()
.zip(&qname[qname_offset..qname_offset + label_len])
.for_each(|(a, b)| {
debug_assert!(a.eq_ignore_ascii_case(b));
*a = *b
});
offset += label_len;
qname_offset += label_len;
}
}
}
Ok(())
}
fn skip_name(packet: &[u8], offset: usize) -> Result<usize, Error> {
let packet_len = packet.len();
ensure!(offset < packet_len - 1, "Short packet");

@ -37,7 +37,7 @@ pub async fn resolve_udp(
if response_addr == globals.upstream_addr
&& response_len >= DNS_HEADER_SIZE
&& dns::tid(&response) == tid
&& packet_qname == dns::qname(&response)?.as_slice()
&& packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice())
{
break;
}
@ -84,7 +84,7 @@ pub async fn resolve_tcp(
ext_socket.read_exact(&mut response).await?;
ensure!(dns::tid(&response) == tid, "Unexpected transaction ID");
ensure!(
packet_qname == dns::qname(&response)?.as_slice(),
packet_qname.eq_ignore_ascii_case(dns::qname(&response)?.as_slice()),
"Unexpected query name in the response"
);
Ok(response)
@ -135,6 +135,7 @@ pub async fn resolve(
globals.cache.lock().insert(packet_hash, cached_response);
}
dns::set_tid(&mut response, original_tid);
dns::recase_qname(&mut response, &packet_qname)?;
#[cfg(feature = "metrics")]
globals
.varz
@ -157,6 +158,7 @@ pub async fn get_cached_response_or_resolve(
}
let original_tid = dns::tid(&packet);
dns::set_tid(&mut packet, 0);
dns::normalize_qname(&mut packet)?;
let mut hasher = globals.hasher;
hasher.write(&packet);
let packet_hash = hasher.finish128().as_u128();
@ -172,12 +174,14 @@ pub async fn get_cached_response_or_resolve(
let cached_response = match cached_response {
None => None,
Some(mut cached_response) => {
cached_response.set_tid(original_tid);
if !cached_response.has_expired() {
trace!("Cached");
#[cfg(feature = "metrics")]
globals.varz.client_queries_cached.inc();
return Ok(cached_response.into_response());
cached_response.set_tid(original_tid);
let mut response = cached_response.into_response();
dns::recase_qname(&mut response, &packet_qname)?;
return Ok(response);
}
trace!("Expired");
#[cfg(feature = "metrics")]

Loading…
Cancel
Save