First DNS compile. Took DNS code from https://github.com/EmilHernvall/hermes.

pull/2/head
Revertron 3 years ago
parent b4ae51088d
commit 4b5e5112da

@ -20,11 +20,13 @@ num-traits = "0.2"
bincode = "1.2.0"
groestl = "0.8.0"
base64 = "0.11.0"
chrono = "0.4.9"
chrono = { version = "0.4.13", features = ["serde"] }
rand = "0.7.2"
sqlite = "0.25.3"
uuid = { version = "0.8.2", features = ["serde", "v4"] }
mio = { version = "0.7", features = ["os-poll", "net"] }
# for DNS from hermes
derive_more = "0.99.9"
[build-dependencies]
winres = "0.1"

@ -0,0 +1,257 @@
//! contains the data store for local zones
use std::collections::{BTreeMap, BTreeSet};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard};
use derive_more::{Display, From, Error};
use crate::dns::buffer::{PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
#[derive(Debug, Display, From, Error)]
pub enum AuthorityError {
Buffer(crate::dns::buffer::BufferError),
Protocol(crate::dns::protocol::ProtocolError),
Io(std::io::Error),
PoisonedLock,
}
type Result<T> = std::result::Result<T, AuthorityError>;
#[derive(Clone, Debug, Default)]
pub struct Zone {
pub domain: String,
pub m_name: String,
pub r_name: String,
pub serial: u32,
pub refresh: u32,
pub retry: u32,
pub expire: u32,
pub minimum: u32,
pub records: BTreeSet<DnsRecord>,
}
impl Zone {
pub fn new(domain: String, m_name: String, r_name: String) -> Zone {
Zone {
domain: domain,
m_name: m_name,
r_name: r_name,
serial: 0,
refresh: 0,
retry: 0,
expire: 0,
minimum: 0,
records: BTreeSet::new(),
}
}
pub fn add_record(&mut self, rec: &DnsRecord) -> bool {
self.records.insert(rec.clone())
}
pub fn delete_record(&mut self, rec: &DnsRecord) -> bool {
self.records.remove(rec)
}
}
#[derive(Default)]
pub struct Zones {
zones: BTreeMap<String, Zone>,
}
impl<'a> Zones {
pub fn new() -> Zones {
Zones {
zones: BTreeMap::new(),
}
}
pub fn load(&mut self) -> Result<()> {
let zones_dir = Path::new("zones").read_dir()?;
for wrapped_filename in zones_dir {
let filename = match wrapped_filename {
Ok(x) => x,
Err(_) => continue,
};
let mut zone_file = match File::open(filename.path()) {
Ok(x) => x,
Err(_) => continue,
};
let mut buffer = StreamPacketBuffer::new(&mut zone_file);
let mut zone = Zone::new(String::new(), String::new(), String::new());
buffer.read_qname(&mut zone.domain)?;
buffer.read_qname(&mut zone.m_name)?;
buffer.read_qname(&mut zone.r_name)?;
zone.serial = buffer.read_u32()?;
zone.refresh = buffer.read_u32()?;
zone.retry = buffer.read_u32()?;
zone.expire = buffer.read_u32()?;
zone.minimum = buffer.read_u32()?;
let record_count = buffer.read_u32()?;
for _ in 0..record_count {
let rr = DnsRecord::read(&mut buffer)?;
zone.add_record(&rr);
}
println!("Loaded zone {} with {} records", zone.domain, record_count);
self.zones.insert(zone.domain.clone(), zone);
}
Ok(())
}
pub fn save(&mut self) -> Result<()> {
let zones_dir = Path::new("zones");
for zone in self.zones.values() {
let filename = zones_dir.join(Path::new(&zone.domain));
let mut zone_file = match File::create(&filename) {
Ok(x) => x,
Err(_) => {
println!("Failed to save file {:?}", filename);
continue;
}
};
let mut buffer = VectorPacketBuffer::new();
let _ = buffer.write_qname(&zone.domain);
let _ = buffer.write_qname(&zone.m_name);
let _ = buffer.write_qname(&zone.r_name);
let _ = buffer.write_u32(zone.serial);
let _ = buffer.write_u32(zone.refresh);
let _ = buffer.write_u32(zone.retry);
let _ = buffer.write_u32(zone.expire);
let _ = buffer.write_u32(zone.minimum);
let _ = buffer.write_u32(zone.records.len() as u32);
for rec in &zone.records {
let _ = rec.write(&mut buffer);
}
let _ = zone_file.write(&buffer.buffer[0..buffer.pos]);
}
Ok(())
}
pub fn zones(&self) -> Vec<&Zone> {
self.zones.values().collect()
}
pub fn add_zone(&mut self, zone: Zone) {
self.zones.insert(zone.domain.clone(), zone);
}
pub fn get_zone(&'a self, domain: &str) -> Option<&'a Zone> {
self.zones.get(domain)
}
pub fn get_zone_mut(&'a mut self, domain: &str) -> Option<&'a mut Zone> {
self.zones.get_mut(domain)
}
}
#[derive(Default)]
pub struct Authority {
zones: RwLock<Zones>,
}
impl Authority {
pub fn new() -> Authority {
Authority {
zones: RwLock::new(Zones::new()),
}
}
pub fn load(&self) -> Result<()> {
let mut zones = self
.zones
.write()
.map_err(|_| AuthorityError::PoisonedLock)?;
zones.load()?;
Ok(())
}
pub fn query(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
let zones = match self.zones.read().ok() {
Some(x) => x,
None => return None,
};
let mut best_match = None;
for zone in zones.zones() {
if !qname.ends_with(&zone.domain) {
continue;
}
if let Some((len, _)) = best_match {
if len < zone.domain.len() {
best_match = Some((zone.domain.len(), zone));
}
} else {
best_match = Some((zone.domain.len(), zone));
}
}
let zone = match best_match {
Some((_, zone)) => zone,
None => return None,
};
let mut packet = DnsPacket::new();
packet.header.authoritative_answer = true;
for rec in &zone.records {
let domain = match rec.get_domain() {
Some(x) => x,
None => continue,
};
if &domain != qname {
continue;
}
let rtype = rec.get_querytype();
if qtype == rtype || (qtype == QueryType::A && rtype == QueryType::CNAME) {
packet.answers.push(rec.clone());
}
}
if packet.answers.is_empty() {
packet.header.rescode = ResultCode::NXDOMAIN;
packet.authorities.push(DnsRecord::SOA {
domain: zone.domain.clone(),
m_name: zone.m_name.clone(),
r_name: zone.r_name.clone(),
serial: zone.serial,
refresh: zone.refresh,
retry: zone.retry,
expire: zone.expire,
minimum: zone.minimum,
ttl: TransientTtl(zone.minimum),
});
}
Some(packet)
}
pub fn read(&self) -> LockResult<RwLockReadGuard<'_, Zones>> {
self.zones.read()
}
pub fn write(&self) -> LockResult<RwLockWriteGuard<'_, Zones>> {
self.zones.write()
}
}

@ -0,0 +1,487 @@
//! buffers for use when writing and reading dns packets
use std::collections::BTreeMap;
use std::io::Read;
use derive_more::{Display, Error, From};
#[derive(Debug, Display, From, Error)]
pub enum BufferError {
Io(std::io::Error),
EndOfBuffer,
}
type Result<T> = std::result::Result<T, BufferError>;
pub trait PacketBuffer {
fn read(&mut self) -> Result<u8>;
fn get(&mut self, pos: usize) -> Result<u8>;
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]>;
fn write(&mut self, val: u8) -> Result<()>;
fn set(&mut self, pos: usize, val: u8) -> Result<()>;
fn pos(&self) -> usize;
fn seek(&mut self, pos: usize) -> Result<()>;
fn step(&mut self, steps: usize) -> Result<()>;
fn find_label(&self, label: &str) -> Option<usize>;
fn save_label(&mut self, label: &str, pos: usize);
fn write_u8(&mut self, val: u8) -> Result<()> {
self.write(val)?;
Ok(())
}
fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
self.set(pos, (val >> 8) as u8)?;
self.set(pos + 1, (val & 0xFF) as u8)?;
Ok(())
}
fn write_u16(&mut self, val: u16) -> Result<()> {
self.write((val >> 8) as u8)?;
self.write((val & 0xFF) as u8)?;
Ok(())
}
fn write_u32(&mut self, val: u32) -> Result<()> {
self.write(((val >> 24) & 0xFF) as u8)?;
self.write(((val >> 16) & 0xFF) as u8)?;
self.write(((val >> 8) & 0xFF) as u8)?;
self.write(((val >> 0) & 0xFF) as u8)?;
Ok(())
}
fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>();
let mut jump_performed = false;
for (i, label) in split_str.iter().enumerate() {
let search_lbl = split_str[i..split_str.len()].join(".");
if let Some(prev_pos) = self.find_label(&search_lbl) {
let jump_inst = (prev_pos as u16) | 0xC000;
self.write_u16(jump_inst)?;
jump_performed = true;
break;
}
let pos = self.pos();
self.save_label(&search_lbl, pos);
let len = label.len();
self.write_u8(len as u8)?;
for b in label.as_bytes() {
self.write_u8(*b)?;
}
}
if !jump_performed {
self.write_u8(0)?;
}
Ok(())
}
fn read_u16(&mut self) -> Result<u16> {
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
Ok(res)
}
fn read_u32(&mut self) -> Result<u32> {
let res = ((self.read()? as u32) << 24)
| ((self.read()? as u32) << 16)
| ((self.read()? as u32) << 8)
| ((self.read()? as u32) << 0);
Ok(res)
}
fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
let mut pos = self.pos();
let mut jumped = false;
let mut delim = "";
loop {
let len = self.get(pos)?;
// A two byte sequence, where the two highest bits of the first byte is
// set, represents a offset relative to the start of the buffer. We
// handle this by jumping to the offset, setting a flag to indicate
// that we shouldn't update the shared buffer position once done.
if (len & 0xC0) > 0 {
// When a jump is performed, we only modify the shared buffer
// position once, and avoid making the change later on.
if !jumped {
self.seek(pos + 2)?;
}
let b2 = self.get(pos + 1)? as u16;
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
pos = offset as usize;
jumped = true;
continue;
}
pos += 1;
// Names are terminated by an empty label of length 0
if len == 0 {
break;
}
outstr.push_str(delim);
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
pos += len as usize;
}
if !jumped {
self.seek(pos)?;
}
Ok(())
}
}
#[derive(Default)]
pub struct VectorPacketBuffer {
pub buffer: Vec<u8>,
pub pos: usize,
pub label_lookup: BTreeMap<String, usize>,
}
impl VectorPacketBuffer {
pub fn new() -> VectorPacketBuffer {
VectorPacketBuffer {
buffer: Vec::new(),
pos: 0,
label_lookup: BTreeMap::new(),
}
}
}
impl PacketBuffer for VectorPacketBuffer {
fn find_label(&self, label: &str) -> Option<usize> {
self.label_lookup.get(label).cloned()
}
fn save_label(&mut self, label: &str, pos: usize) {
self.label_lookup.insert(label.to_string(), pos);
}
fn read(&mut self) -> Result<u8> {
let res = self.buffer[self.pos];
self.pos += 1;
Ok(res)
}
fn get(&mut self, pos: usize) -> Result<u8> {
Ok(self.buffer[pos])
}
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
Ok(&self.buffer[start..start + len as usize])
}
fn write(&mut self, val: u8) -> Result<()> {
self.buffer.push(val);
self.pos += 1;
Ok(())
}
fn set(&mut self, pos: usize, val: u8) -> Result<()> {
self.buffer[pos] = val;
Ok(())
}
fn pos(&self) -> usize {
self.pos
}
fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos;
Ok(())
}
fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps;
Ok(())
}
}
pub struct StreamPacketBuffer<'a, T>
where
T: Read,
{
pub stream: &'a mut T,
pub buffer: Vec<u8>,
pub pos: usize,
}
impl<'a, T> StreamPacketBuffer<'a, T>
where
T: Read + 'a,
{
pub fn new(stream: &'a mut T) -> StreamPacketBuffer<'_, T> {
StreamPacketBuffer {
stream: stream,
buffer: Vec::new(),
pos: 0,
}
}
}
impl<'a, T> PacketBuffer for StreamPacketBuffer<'a, T>
where
T: Read + 'a,
{
fn find_label(&self, _: &str) -> Option<usize> {
None
}
fn save_label(&mut self, _: &str, _: usize) {
unimplemented!();
}
fn read(&mut self) -> Result<u8> {
while self.pos >= self.buffer.len() {
let mut local_buffer = [0; 1];
self.stream.read(&mut local_buffer)?;
self.buffer.push(local_buffer[0]);
}
let res = self.buffer[self.pos];
self.pos += 1;
Ok(res)
}
fn get(&mut self, pos: usize) -> Result<u8> {
while pos >= self.buffer.len() {
let mut local_buffer = [0; 1];
self.stream.read(&mut local_buffer)?;
self.buffer.push(local_buffer[0]);
}
Ok(self.buffer[pos])
}
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
while start + len > self.buffer.len() {
let mut local_buffer = [0; 1];
self.stream.read(&mut local_buffer)?;
self.buffer.push(local_buffer[0]);
}
Ok(&self.buffer[start..start + len as usize])
}
fn write(&mut self, _: u8) -> Result<()> {
unimplemented!();
}
fn set(&mut self, _: usize, _: u8) -> Result<()> {
unimplemented!();
}
fn pos(&self) -> usize {
self.pos
}
fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos;
Ok(())
}
fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps;
Ok(())
}
}
pub struct BytePacketBuffer {
pub buf: [u8; 512],
pub pos: usize,
}
impl BytePacketBuffer {
pub fn new() -> BytePacketBuffer {
BytePacketBuffer {
buf: [0; 512],
pos: 0,
}
}
}
impl Default for BytePacketBuffer {
fn default() -> Self {
BytePacketBuffer::new()
}
}
impl PacketBuffer for BytePacketBuffer {
fn find_label(&self, _: &str) -> Option<usize> {
None
}
fn save_label(&mut self, _: &str, _: usize) {}
fn read(&mut self) -> Result<u8> {
if self.pos >= 512 {
return Err(BufferError::EndOfBuffer);
}
let res = self.buf[self.pos];
self.pos += 1;
Ok(res)
}
fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= 512 {
return Err(BufferError::EndOfBuffer);
}
Ok(self.buf[pos])
}
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= 512 {
return Err(BufferError::EndOfBuffer);
}
Ok(&self.buf[start..start + len as usize])
}
fn write(&mut self, val: u8) -> Result<()> {
if self.pos >= 512 {
return Err(BufferError::EndOfBuffer);
}
self.buf[self.pos] = val;
self.pos += 1;
Ok(())
}
fn set(&mut self, pos: usize, val: u8) -> Result<()> {
self.buf[pos] = val;
Ok(())
}
fn pos(&self) -> usize {
self.pos
}
fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos;
Ok(())
}
fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_qname() {
let mut buffer = VectorPacketBuffer::new();
let instr1 = "a.google.com".to_string();
let instr2 = "b.google.com".to_string();
// First write the standard string
match buffer.write_qname(&instr1) {
Ok(_) => {}
Err(_) => panic!(),
}
// Then we set up a slight variation with relies on a jump back to the data of
// the first name
let crafted_data = [0x01, b'b' as u8, 0xC0, 0x02];
for b in &crafted_data {
match buffer.write_u8(*b) {
Ok(_) => {}
Err(_) => panic!(),
}
}
// Reset the buffer position for reading
buffer.pos = 0;
// Read the standard name
let mut outstr1 = String::new();
match buffer.read_qname(&mut outstr1) {
Ok(_) => {}
Err(_) => panic!(),
}
assert_eq!(instr1, outstr1);
// Read the name with a jump
let mut outstr2 = String::new();
match buffer.read_qname(&mut outstr2) {
Ok(_) => {}
Err(_) => panic!(),
}
assert_eq!(instr2, outstr2);
// Make sure we're now at the end of the buffer
assert_eq!(buffer.pos, buffer.buffer.len());
}
#[test]
fn test_write_qname() {
let mut buffer = VectorPacketBuffer::new();
match buffer.write_qname(&"ns1.google.com".to_string()) {
Ok(_) => {}
Err(_) => panic!(),
}
match buffer.write_qname(&"ns2.google.com".to_string()) {
Ok(_) => {}
Err(_) => panic!(),
}
assert_eq!(22, buffer.pos());
match buffer.seek(0) {
Ok(_) => {}
Err(_) => panic!(),
}
let mut str1 = String::new();
match buffer.read_qname(&mut str1) {
Ok(_) => {}
Err(_) => panic!(),
}
assert_eq!("ns1.google.com", str1);
let mut str2 = String::new();
match buffer.read_qname(&mut str2) {
Ok(_) => {}
Err(_) => panic!(),
}
assert_eq!("ns2.google.com", str2);
}
}

@ -0,0 +1,462 @@
//! a threadsafe cache for DNS information
extern crate serde;
use std::clone::Clone;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use chrono::*;
use derive_more::{Display, Error, From};
use serde::{Deserialize, Serialize};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
#[derive(Debug, Display, From, Error)]
pub enum CacheError {
Io(std::io::Error),
PoisonedLock,
}
type Result<T> = std::result::Result<T, CacheError>;
pub enum CacheState {
PositiveCache,
NegativeCache,
NotCached,
}
#[derive(Clone, Eq, Debug, Serialize, Deserialize)]
pub struct RecordEntry {
pub record: DnsRecord,
pub timestamp: DateTime<Local>,
}
impl PartialEq<RecordEntry> for RecordEntry {
fn eq(&self, other: &RecordEntry) -> bool {
self.record == other.record
}
}
impl Hash for RecordEntry {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.record.hash(state);
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RecordSet {
NoRecords {
qtype: QueryType,
ttl: u32,
timestamp: DateTime<Local>,
},
Records {
qtype: QueryType,
records: HashSet<RecordEntry>,
},
}
#[derive(Clone, Debug)]
pub struct DomainEntry {
pub domain: String,
pub record_types: HashMap<QueryType, RecordSet>,
pub hits: u32,
pub updates: u32,
}
impl DomainEntry {
pub fn new(domain: String) -> DomainEntry {
DomainEntry {
domain: domain,
record_types: HashMap::new(),
hits: 0,
updates: 0,
}
}
pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) {
self.updates += 1;
let new_set = RecordSet::NoRecords {
qtype: qtype,
ttl: ttl,
timestamp: Local::now(),
};
self.record_types.insert(qtype, new_set);
}
pub fn store_record(&mut self, rec: &DnsRecord) {
self.updates += 1;
let entry = RecordEntry {
record: rec.clone(),
timestamp: Local::now(),
};
if let Some(&mut RecordSet::Records {
ref mut records, ..
}) = self.record_types.get_mut(&rec.get_querytype())
{
if records.contains(&entry) {
records.remove(&entry);
}
records.insert(entry);
return;
}
let mut records = HashSet::new();
records.insert(entry);
let new_set = RecordSet::Records {
qtype: rec.get_querytype(),
records: records,
};
self.record_types.insert(rec.get_querytype(), new_set);
}
pub fn get_cache_state(&self, qtype: QueryType) -> CacheState {
match self.record_types.get(&qtype) {
Some(&RecordSet::Records { ref records, .. }) => {
let now = Local::now();
let mut valid_count = 0;
for entry in records {
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
let expires = entry.timestamp + ttl_offset;
if expires < now {
continue;
}
if entry.record.get_querytype() == qtype {
valid_count += 1;
}
}
if valid_count > 0 {
CacheState::PositiveCache
} else {
CacheState::NotCached
}
}
Some(&RecordSet::NoRecords { ttl, timestamp, .. }) => {
let now = Local::now();
let ttl_offset = Duration::seconds(ttl as i64);
let expires = timestamp + ttl_offset;
if expires < now {
CacheState::NotCached
} else {
CacheState::NegativeCache
}
}
None => CacheState::NotCached,
}
}
pub fn fill_queryresult(&self, qtype: QueryType, result_vec: &mut Vec<DnsRecord>) {
let now = Local::now();
let current_set = match self.record_types.get(&qtype) {
Some(x) => x,
None => return,
};
if let RecordSet::Records { ref records, .. } = *current_set {
for entry in records {
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
let expires = entry.timestamp + ttl_offset;
if expires < now {
continue;
}
if entry.record.get_querytype() == qtype {
result_vec.push(entry.record.clone());
}
}
}
}
}
#[derive(Default)]
pub struct Cache {
domain_entries: BTreeMap<String, Arc<DomainEntry>>,
}
impl Cache {
pub fn new() -> Cache {
Cache {
domain_entries: BTreeMap::new(),
}
}
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
match self.domain_entries.get(qname) {
Some(x) => x.get_cache_state(qtype),
None => CacheState::NotCached,
}
}
fn fill_queryresult(
&mut self,
qname: &str,
qtype: QueryType,
result_vec: &mut Vec<DnsRecord>,
increment_stats: bool,
) {
if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
if increment_stats {
domain_entry.hits += 1
}
domain_entry.fill_queryresult(qtype, result_vec);
}
}
pub fn lookup(&mut self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
match self.get_cache_state(qname, qtype) {
CacheState::PositiveCache => {
let mut qr = DnsPacket::new();
self.fill_queryresult(qname, qtype, &mut qr.answers, true);
self.fill_queryresult(qname, QueryType::NS, &mut qr.authorities, false);
Some(qr)
}
CacheState::NegativeCache => {
let mut qr = DnsPacket::new();
qr.header.rescode = ResultCode::NXDOMAIN;
Some(qr)
}
CacheState::NotCached => None,
}
}
pub fn store(&mut self, records: &[DnsRecord]) {
for rec in records {
let domain = match rec.get_domain() {
Some(x) => x,
None => continue,
};
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain).and_then(Arc::get_mut) {
rs.store_record(rec);
continue;
}
let mut rs = DomainEntry::new(domain.clone());
rs.store_record(rec);
self.domain_entries.insert(domain.clone(), Arc::new(rs));
}
}
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
if let Some(ref mut rs) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
rs.store_nxdomain(qtype, ttl);
return;
}
let mut rs = DomainEntry::new(qname.to_string());
rs.store_nxdomain(qtype, ttl);
self.domain_entries.insert(qname.to_string(), Arc::new(rs));
}
}
#[derive(Default)]
pub struct SynchronizedCache {
pub cache: RwLock<Cache>,
}
impl SynchronizedCache {
pub fn new() -> SynchronizedCache {
SynchronizedCache {
cache: RwLock::new(Cache::new()),
}
}
pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
let mut list = Vec::new();
for rs in cache.domain_entries.values() {
list.push(rs.clone());
}
Ok(list)
}
pub fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
let mut cache = match self.cache.write() {
Ok(x) => x,
Err(_) => return None,
};
cache.lookup(qname, qtype)
}
pub fn store(&self, records: &[DnsRecord]) -> Result<()> {
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
cache.store(records);
Ok(())
}
pub fn store_nxdomain(&self, qname: &str, qtype: QueryType, ttl: u32) -> Result<()> {
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
cache.store_nxdomain(qname, qtype, ttl);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dns::protocol::{DnsRecord, QueryType, ResultCode, TransientTtl};
#[test]
fn test_cache() {
let mut cache = Cache::new();
// Verify that no data is returned when nothing is present
if cache.lookup("www.google.com", QueryType::A).is_some() {
panic!()
}
// Register a negative cache entry
cache.store_nxdomain("www.google.com", QueryType::A, 3600);
// Verify that we get a response, with the NXDOMAIN flag set
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);
}
// Register a negative cache entry with no TTL
cache.store_nxdomain("www.yahoo.com", QueryType::A, 0);
// And check that no such result is actually returned, since it's expired
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
panic!()
}
// Now add some actual records
let mut records = Vec::new();
records.push(DnsRecord::A {
domain: "www.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
records.push(DnsRecord::A {
domain: "www.yahoo.com".to_string(),
addr: "127.0.0.2".parse().unwrap(),
ttl: TransientTtl(0),
});
records.push(DnsRecord::CNAME {
domain: "www.microsoft.com".to_string(),
host: "www.somecdn.com".to_string(),
ttl: TransientTtl(3600),
});
cache.store(&records);
// Test for successful lookup
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
assert_eq!(records[0], packet.answers[0]);
} else {
panic!();
}
// Test for failed lookup, since no CNAME's are known for this domain
if cache.lookup("www.google.com", QueryType::CNAME).is_some() {
panic!();
}
// Check for successful CNAME lookup
if let Some(packet) = cache.lookup("www.microsoft.com", QueryType::CNAME) {
assert_eq!(records[2], packet.answers[0]);
} else {
panic!();
}
// This lookup should fail, since it has expired due to the 0 second TTL
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
panic!();
}
let mut records2 = Vec::new();
records2.push(DnsRecord::A {
domain: "www.yahoo.com".to_string(),
addr: "127.0.0.2".parse().unwrap(),
ttl: TransientTtl(3600),
});
cache.store(&records2);
// And now it should succeed, since the record has been store
if !cache.lookup("www.yahoo.com", QueryType::A).is_some() {
panic!();
}
// Check stat counter behavior
assert_eq!(3, cache.domain_entries.len());
assert_eq!(
1,
cache
.domain_entries
.get(&"www.google.com".to_string())
.unwrap()
.hits
);
assert_eq!(
2,
cache
.domain_entries
.get(&"www.google.com".to_string())
.unwrap()
.updates
);
assert_eq!(
1,
cache
.domain_entries
.get(&"www.yahoo.com".to_string())
.unwrap()
.hits
);
assert_eq!(
3,
cache
.domain_entries
.get(&"www.yahoo.com".to_string())
.unwrap()
.updates
);
assert_eq!(
1,
cache
.domain_entries
.get(&"www.microsoft.com".to_string())
.unwrap()
.updates
);
assert_eq!(
1,
cache
.domain_entries
.get(&"www.microsoft.com".to_string())
.unwrap()
.hits
);
}
}

@ -0,0 +1,400 @@
//! client for sending DNS queries to other servers
use std::io::Write;
use std::marker::{Send, Sync};
use std::net::{TcpStream, UdpSocket};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{sleep, Builder};
use std::time::Duration as SleepDuration;
use chrono::*;
use derive_more::{Display, Error, From};
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer};
use crate::dns::netutil::{read_packet_length, write_packet_length};
use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType};
#[derive(Debug, Display, From, Error)]
pub enum ClientError {
Protocol(crate::dns::protocol::ProtocolError),
Io(std::io::Error),
PoisonedLock,
LookupFailed,
TimeOut,
}
type Result<T> = std::result::Result<T, ClientError>;
pub trait DnsClient {
fn get_sent_count(&self) -> usize;
fn get_failed_count(&self) -> usize;
fn run(&self) -> Result<()>;
fn send_query(
&self,
qname: &str,
qtype: QueryType,
server: (&str, u16),
recursive: bool,
) -> Result<DnsPacket>;
}
/// The UDP client
///
/// This includes a fair bit of synchronization due to the stateless nature of UDP.
/// When many queries are sent in parallell, the response packets can come back
/// in any order. For that reason, we fire off replies on the sending thread, but
/// handle replies on a single thread. A channel is created for every response,
/// and the caller will block on the channel until the a response is received.
pub struct DnsNetworkClient {
total_sent: AtomicUsize,
total_failed: AtomicUsize,
/// Counter for assigning packet ids
seq: AtomicUsize,
/// The listener socket
socket: UdpSocket,
/// Queries in progress
pending_queries: Arc<Mutex<Vec<PendingQuery>>>,
}
/// A query in progress. This struct holds the `id` if the request, and a channel
/// endpoint for returning a response back to the thread from which the query
/// was posed.
struct PendingQuery {
seq: u16,
timestamp: DateTime<Local>,
tx: Sender<Option<DnsPacket>>,
}
unsafe impl Send for DnsNetworkClient {}
unsafe impl Sync for DnsNetworkClient {}
impl DnsNetworkClient {
pub fn new(port: u16) -> DnsNetworkClient {
DnsNetworkClient {
total_sent: AtomicUsize::new(0),
total_failed: AtomicUsize::new(0),
seq: AtomicUsize::new(0),
socket: UdpSocket::bind(("0.0.0.0", port)).unwrap(),
pending_queries: Arc::new(Mutex::new(Vec::new())),
}
}
/// Send a DNS query using TCP transport
///
/// This is much simpler than using UDP, since the kernel will take care of
/// packet ordering, connection state, timeouts etc.
pub fn send_tcp_query(
&self,
qname: &str,
qtype: QueryType,
server: (&str, u16),
recursive: bool,
) -> Result<DnsPacket> {
let _ = self.total_sent.fetch_add(1, Ordering::Release);
// Prepare request
let mut packet = DnsPacket::new();
packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16;
if packet.header.id + 1 == 0xFFFF {
self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst);
}
packet.header.questions = 1;
packet.header.recursion_desired = recursive;
packet.questions.push(DnsQuestion::new(qname.into(), qtype));
// Send query
let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer, 0xFFFF)?;
let mut socket = TcpStream::connect(server)?;
write_packet_length(&mut socket, req_buffer.pos())?;
socket.write(&req_buffer.buf[0..req_buffer.pos])?;
socket.flush()?;
let _ = read_packet_length(&mut socket)?;
let mut stream_buffer = StreamPacketBuffer::new(&mut socket);
let packet = DnsPacket::from_buffer(&mut stream_buffer)?;
Ok(packet)
}
/// Send a DNS query using UDP transport
///
/// This will construct a query packet, and fire it off to the specified server.
/// The query is sent from the callee thread, but responses are read on a
/// worker thread, and returned to this thread through a channel. Thus this
/// method is thread safe, and can be used from any number of threads in
/// parallell.
pub fn send_udp_query(
&self,
qname: &str,
qtype: QueryType,
server: (&str, u16),
recursive: bool,
) -> Result<DnsPacket> {
let _ = self.total_sent.fetch_add(1, Ordering::Release);
// Prepare request
let mut packet = DnsPacket::new();
packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16;
if packet.header.id + 1 == 0xFFFF {
self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst);
}
packet.header.questions = 1;
packet.header.recursion_desired = recursive;
packet
.questions
.push(DnsQuestion::new(qname.to_string(), qtype));
// Create a return channel, and add a `PendingQuery` to the list of lookups
// in progress
let (tx, rx) = channel();
{
let mut pending_queries = self
.pending_queries
.lock()
.map_err(|_| ClientError::PoisonedLock)?;
pending_queries.push(PendingQuery {
seq: packet.header.id,
timestamp: Local::now(),
tx: tx,
});
}
// Send query
let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer, 512)?;
self.socket
.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
// Wait for response
match rx.recv() {
Ok(Some(qr)) => Ok(qr),
Ok(None) => {
let _ = self.total_failed.fetch_add(1, Ordering::Release);
Err(ClientError::TimeOut)
}
Err(_) => {
let _ = self.total_failed.fetch_add(1, Ordering::Release);
Err(ClientError::LookupFailed)
}
}
}
}
impl DnsClient for DnsNetworkClient {
fn get_sent_count(&self) -> usize {
self.total_sent.load(Ordering::Acquire)
}
fn get_failed_count(&self) -> usize {
self.total_failed.load(Ordering::Acquire)
}
/// The run method launches a worker thread. Unless this thread is running, no
/// responses will ever be generated, and clients will just block indefinitely.
fn run(&self) -> Result<()> {
// Start the thread for handling incoming responses
{
let socket_copy = self.socket.try_clone()?;
let pending_queries_lock = self.pending_queries.clone();
Builder::new()
.name("DnsNetworkClient-worker-thread".into())
.spawn(move || {
loop {
// Read data into a buffer
let mut res_buffer = BytePacketBuffer::new();
match socket_copy.recv_from(&mut res_buffer.buf) {
Ok(_) => {}
Err(_) => {
continue;
}
}
// Construct a DnsPacket from buffer, skipping the packet if parsing
// failed
let packet = match DnsPacket::from_buffer(&mut res_buffer) {
Ok(packet) => packet,
Err(err) => {
println!(
"DnsNetworkClient failed to parse packet with error: {}",
err
);
continue;
}
};
// Acquire a lock on the pending_queries list, and search for a
// matching PendingQuery to which to deliver the response.
if let Ok(mut pending_queries) = pending_queries_lock.lock() {
let mut matched_query = None;
for (i, pending_query) in pending_queries.iter().enumerate() {
if pending_query.seq == packet.header.id {
// Matching query found, send the response
let _ = pending_query.tx.send(Some(packet.clone()));
// Mark this index for removal from list
matched_query = Some(i);
break;
}
}
if let Some(idx) = matched_query {
pending_queries.remove(idx);
} else {
println!("Discarding response for: {:?}", packet.questions[0]);
}
}
}
})?;
}
// Start the thread for timing out requests
{
let pending_queries_lock = self.pending_queries.clone();
Builder::new()
.name("DnsNetworkClient-timeout-thread".into())
.spawn(move || {
let timeout = Duration::seconds(1);
loop {
if let Ok(mut pending_queries) = pending_queries_lock.lock() {
let mut finished_queries = Vec::new();
for (i, pending_query) in pending_queries.iter().enumerate() {
let expires = pending_query.timestamp + timeout;
if expires < Local::now() {
let _ = pending_query.tx.send(None);
finished_queries.push(i);
}
}
// Remove `PendingQuery` objects from the list, in reverse order
for idx in finished_queries.iter().rev() {
pending_queries.remove(*idx);
}
}
sleep(SleepDuration::from_millis(100));
}
})?;
}
Ok(())
}
fn send_query(
&self,
qname: &str,
qtype: QueryType,
server: (&str, u16),
recursive: bool,
) -> Result<DnsPacket> {
let packet = self.send_udp_query(qname, qtype, server, recursive)?;
if !packet.header.truncated_message {
return Ok(packet);
}
println!("Truncated response - resending as TCP");
self.send_tcp_query(qname, qtype, server, recursive)
}
}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType};
pub type StubCallback = dyn Fn(&str, QueryType, (&str, u16), bool) -> Result<DnsPacket>;
pub struct DnsStubClient {
callback: Box<StubCallback>,
}
impl<'a> DnsStubClient {
pub fn new(callback: Box<StubCallback>) -> DnsStubClient {
DnsStubClient { callback: callback }
}
}
unsafe impl Send for DnsStubClient {}
unsafe impl Sync for DnsStubClient {}
impl DnsClient for DnsStubClient {
fn get_sent_count(&self) -> usize {
0
}
fn get_failed_count(&self) -> usize {
0
}
fn run(&self) -> Result<()> {
Ok(())
}
fn send_query(
&self,
qname: &str,
qtype: QueryType,
server: (&str, u16),
recursive: bool,
) -> Result<DnsPacket> {
(self.callback)(qname, qtype, server, recursive)
}
}
#[test]
pub fn test_udp_client() {
let client = DnsNetworkClient::new(31456);
client.run().unwrap();
let res = client
.send_udp_query("google.com", QueryType::A, ("8.8.8.8", 53), true)
.unwrap();
assert_eq!(res.questions[0].name, "google.com");
assert!(res.answers.len() > 0);
match res.answers[0] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
}
}
#[test]
pub fn test_tcp_client() {
let client = DnsNetworkClient::new(31457);
let res = client
.send_tcp_query("google.com", QueryType::A, ("8.8.8.8", 53), true)
.unwrap();
assert_eq!(res.questions[0].name, "google.com");
assert!(res.answers.len() > 0);
match res.answers[0] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
}
}
}

@ -0,0 +1,140 @@
//! The `ServerContext in this thread holds the common state across the server
use std::fs;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use derive_more::{Display, Error, From};
use crate::dns::authority::Authority;
use crate::dns::cache::SynchronizedCache;
use crate::dns::client::{DnsClient, DnsNetworkClient};
use crate::dns::resolve::{DnsResolver, ForwardingDnsResolver, RecursiveDnsResolver};
#[derive(Debug, Display, From, Error)]
pub enum ContextError {
Authority(crate::dns::authority::AuthorityError),
Client(crate::dns::client::ClientError),
Io(std::io::Error),
}
type Result<T> = std::result::Result<T, ContextError>;
pub struct ServerStatistics {
pub tcp_query_count: AtomicUsize,
pub udp_query_count: AtomicUsize,
}
impl ServerStatistics {
pub fn get_tcp_query_count(&self) -> usize {
self.tcp_query_count.load(Ordering::Acquire)
}
pub fn get_udp_query_count(&self) -> usize {
self.udp_query_count.load(Ordering::Acquire)
}
}
pub enum ResolveStrategy {
Recursive,
Forward { host: String, port: u16 },
}
pub struct ServerContext {
pub authority: Authority,
pub cache: SynchronizedCache,
pub client: Box<dyn DnsClient + Sync + Send>,
pub dns_port: u16,
pub api_port: u16,
pub resolve_strategy: ResolveStrategy,
pub allow_recursive: bool,
pub enable_udp: bool,
pub enable_tcp: bool,
pub enable_api: bool,
pub statistics: ServerStatistics,
pub zones_dir: &'static str
}
impl Default for ServerContext {
fn default() -> Self {
ServerContext::new()
}
}
impl ServerContext {
pub fn new() -> ServerContext {
ServerContext {
authority: Authority::new(),
cache: SynchronizedCache::new(),
client: Box::new(DnsNetworkClient::new(34255)),
dns_port: 53,
api_port: 5380,
resolve_strategy: ResolveStrategy::Recursive,
allow_recursive: true,
enable_udp: true,
enable_tcp: true,
enable_api: true,
statistics: ServerStatistics {
tcp_query_count: AtomicUsize::new(0),
udp_query_count: AtomicUsize::new(0),
},
zones_dir: "zones",
}
}
pub fn initialize(&mut self) -> Result<()> {
// Create zones directory if it doesn't exist
fs::create_dir_all(self.zones_dir)?;
// Start UDP client thread
self.client.run()?;
// Load authority data
self.authority.load()?;
Ok(())
}
pub fn create_resolver(&self, ptr: Arc<ServerContext>) -> Box<dyn DnsResolver> {
match self.resolve_strategy {
ResolveStrategy::Recursive => Box::new(RecursiveDnsResolver::new(ptr)),
ResolveStrategy::Forward { ref host, port } => {
Box::new(ForwardingDnsResolver::new(ptr, (host.clone(), port)))
}
}
}
}
#[cfg(test)]
pub mod tests {
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use crate::dns::authority::Authority;
use crate::dns::cache::SynchronizedCache;
use crate::dns::client::tests::{DnsStubClient, StubCallback};
use super::*;
pub fn create_test_context(callback: Box<StubCallback>) -> Arc<ServerContext> {
Arc::new(ServerContext {
authority: Authority::new(),
cache: SynchronizedCache::new(),
client: Box::new(DnsStubClient::new(callback)),
dns_port: 53,
api_port: 5380,
resolve_strategy: ResolveStrategy::Recursive,
allow_recursive: true,
enable_udp: true,
enable_tcp: true,
enable_api: true,
statistics: ServerStatistics {
tcp_query_count: AtomicUsize::new(0),
udp_query_count: AtomicUsize::new(0),
},
zones_dir: "zones",
})
}
}

@ -0,0 +1,26 @@
/*
Copyright 2018 Emil Hernvall
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
*/
//! The dns module implements the DNS protocol and the related functions
pub mod authority;
pub mod buffer;
pub mod cache;
pub mod client;
pub mod context;
pub mod protocol;
pub mod resolve;
pub mod server;
mod netutil;

@ -0,0 +1,19 @@
use std::io::{Read, Result, Write};
use std::net::TcpStream;
pub fn read_packet_length(stream: &mut TcpStream) -> Result<u16> {
let mut len_buffer = [0; 2];
stream.read(&mut len_buffer)?;
Ok(((len_buffer[0] as u16) << 8) | (len_buffer[1] as u16))
}
pub fn write_packet_length(stream: &mut TcpStream, len: usize) -> Result<()> {
let mut len_buffer = [0; 2];
len_buffer[0] = (len >> 8) as u8;
len_buffer[1] = (len & 0xFF) as u8;
stream.write(&len_buffer)?;
Ok(())
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,569 @@
//! resolver implementations implementing different strategies for answering
//! incoming queries
use std::sync::Arc;
use std::vec::Vec;
use derive_more::{Display, Error, From};
use crate::dns::context::ServerContext;
use crate::dns::protocol::{DnsPacket, QueryType, ResultCode};
#[derive(Debug, Display, From, Error)]
pub enum ResolveError {
Client(crate::dns::client::ClientError),
Cache(crate::dns::cache::CacheError),
Io(std::io::Error),
NoServerFound,
}
type Result<T> = std::result::Result<T, ResolveError>;
pub trait DnsResolver {
fn get_context(&self) -> Arc<ServerContext>;
fn resolve(&mut self, qname: &str, qtype: QueryType, recursive: bool) -> Result<DnsPacket> {
if let QueryType::UNKNOWN(_) = qtype {
let mut packet = DnsPacket::new();
packet.header.rescode = ResultCode::NOTIMP;
return Ok(packet);
}
let context = self.get_context();
if let Some(qr) = context.authority.query(qname, qtype) {
return Ok(qr);
}
if !recursive || !context.allow_recursive {
let mut packet = DnsPacket::new();
packet.header.rescode = ResultCode::REFUSED;
return Ok(packet);
}
if let Some(qr) = context.cache.lookup(qname, qtype) {
return Ok(qr);
}
if qtype == QueryType::A || qtype == QueryType::AAAA {
if let Some(qr) = context.cache.lookup(qname, QueryType::CNAME) {
return Ok(qr);
}
}
self.perform(qname, qtype)
}
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket>;
}
/// A Forwarding DNS Resolver
///
/// This resolver uses an external DNS server to service a query
pub struct ForwardingDnsResolver {
context: Arc<ServerContext>,
server: (String, u16),
}
impl ForwardingDnsResolver {
pub fn new(context: Arc<ServerContext>, server: (String, u16)) -> ForwardingDnsResolver {
ForwardingDnsResolver {
context: context,
server: server,
}
}
}
impl DnsResolver for ForwardingDnsResolver {
fn get_context(&self) -> Arc<ServerContext> {
self.context.clone()
}
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket> {
let &(ref host, port) = &self.server;
let result = self
.context
.client
.send_query(qname, qtype, (host.as_str(), port), true)?;
self.context.cache.store(&result.answers)?;
Ok(result)
}
}
/// A Recursive DNS resolver
///
/// This resolver can answer any request using the root servers of the internet
pub struct RecursiveDnsResolver {
context: Arc<ServerContext>,
}
impl RecursiveDnsResolver {
pub fn new(context: Arc<ServerContext>) -> RecursiveDnsResolver {
RecursiveDnsResolver { context: context }
}
}
impl DnsResolver for RecursiveDnsResolver {
fn get_context(&self) -> Arc<ServerContext> {
self.context.clone()
}
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// Find the closest name server by splitting the label and progessively
// moving towards the root servers. I.e. check "google.com", then "com",
// and finally "".
let mut tentative_ns = None;
let labels = qname.split('.').collect::<Vec<&str>>();
for lbl_idx in 0..labels.len() + 1 {
let domain = labels[lbl_idx..].join(".");
match self
.context
.cache
.lookup(&domain, QueryType::NS)
.and_then(|qr| qr.get_unresolved_ns(&domain))
.and_then(|ns| self.context.cache.lookup(&ns, QueryType::A))
.and_then(|qr| qr.get_random_a())
{
Some(addr) => {
tentative_ns = Some(addr);
break;
}
None => continue,
}
}
let mut ns = tentative_ns.ok_or_else(|| ResolveError::NoServerFound)?;
// Start querying name servers
loop {
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
let ns_copy = ns.clone();
let server = (ns_copy.as_str(), 53);
let response = self
.context
.client
.send_query(qname, qtype.clone(), server, false)?;
// If we've got an actual answer, we're done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
let _ = self.context.cache.store(&response.answers);
let _ = self.context.cache.store(&response.authorities);
let _ = self.context.cache.store(&response.resources);
return Ok(response.clone());
}
if response.header.rescode == ResultCode::NXDOMAIN {
if let Some(ttl) = response.get_ttl_from_soa() {
let _ = self.context.cache.store_nxdomain(qname, qtype, ttl);
}
return Ok(response.clone());
}
// Otherwise, try to find a new nameserver based on NS and a
// corresponding A record in the additional section
if let Some(new_ns) = response.get_resolved_ns(qname) {
// If there is such a record, we can retry the loop with that NS
ns = new_ns.clone();
let _ = self.context.cache.store(&response.answers);
let _ = self.context.cache.store(&response.authorities);
let _ = self.context.cache.store(&response.resources);
continue;
}
// If not, we'll have to resolve the ip of a NS record
let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x,
None => return Ok(response.clone()),
};
// Recursively resolve the NS
let recursive_response = self.resolve(&new_ns_name, QueryType::A, true)?;
// Pick a random IP and restart
if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone();
} else {
return Ok(response.clone());
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
use super::*;
use crate::dns::context::tests::create_test_context;
use crate::dns::context::ResolveStrategy;
#[test]
fn test_forwarding_resolver() {
let mut context = create_test_context(Box::new(|qname, _, _, _| {
let mut packet = DnsPacket::new();
if qname == "google.com" {
packet.answers.push(DnsRecord::A {
domain: "google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
} else {
packet.header.rescode = ResultCode::NXDOMAIN;
}
Ok(packet)
}));
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
};
}
None => panic!(),
}
let mut resolver = context.create_resolver(context.clone());
// First verify that we get a match back
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(1, res.answers.len());
match res.answers[0] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
}
};
// Do the same lookup again, and verify that it's present in the cache
// and that the counter has been updated
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(1, res.answers.len());
let list = match context.cache.list() {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(1, list.len());
assert_eq!("google.com", list[0].domain);
assert_eq!(1, list[0].record_types.len());
assert_eq!(1, list[0].hits);
};
// Do a failed lookup
{
let res = match resolver.resolve("yahoo.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(0, res.answers.len());
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
};
}
#[test]
fn test_recursive_resolver_with_no_nameserver() {
let context = create_test_context(Box::new(|_, _, _, _| {
let mut packet = DnsPacket::new();
packet.header.rescode = ResultCode::NXDOMAIN;
Ok(packet)
}));
let mut resolver = context.create_resolver(context.clone());
// Expect failure when no name servers are available
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
panic!();
}
}
#[test]
fn test_recursive_resolver_with_missing_a_record() {
let context = create_test_context(Box::new(|_, _, _, _| {
let mut packet = DnsPacket::new();
packet.header.rescode = ResultCode::NXDOMAIN;
Ok(packet)
}));
let mut resolver = context.create_resolver(context.clone());
// Expect failure when no name servers are available
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
panic!();
}
// Insert name server, but no corresponding A record
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "".to_string(),
host: "a.myroot.net".to_string(),
ttl: TransientTtl(3600),
});
let _ = context.cache.store(&nameservers);
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
panic!();
}
}
#[test]
fn test_recursive_resolver_match_order() {
let context = create_test_context(Box::new(|_, _, (server, _), _| {
let mut packet = DnsPacket::new();
if server == "127.0.0.1" {
packet.header.id = 1;
packet.answers.push(DnsRecord::A {
domain: "a.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
return Ok(packet);
} else if server == "127.0.0.2" {
packet.header.id = 2;
packet.answers.push(DnsRecord::A {
domain: "b.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
return Ok(packet);
} else if server == "127.0.0.3" {
packet.header.id = 3;
packet.answers.push(DnsRecord::A {
domain: "c.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
return Ok(packet);
}
packet.header.id = 999;
packet.header.rescode = ResultCode::NXDOMAIN;
Ok(packet)
}));
let mut resolver = context.create_resolver(context.clone());
// Expect failure when no name servers are available
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
panic!();
}
// Insert root servers
{
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "".to_string(),
host: "a.myroot.net".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::A {
domain: "a.myroot.net".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
let _ = context.cache.store(&nameservers);
}
match resolver.resolve("google.com", QueryType::A, true) {
Ok(packet) => {
assert_eq!(1, packet.header.id);
}
Err(_) => panic!(),
}
// Insert TLD servers
{
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "com".to_string(),
host: "a.mytld.net".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::A {
domain: "a.mytld.net".to_string(),
addr: "127.0.0.2".parse().unwrap(),
ttl: TransientTtl(3600),
});
let _ = context.cache.store(&nameservers);
}
match resolver.resolve("google.com", QueryType::A, true) {
Ok(packet) => {
assert_eq!(2, packet.header.id);
}
Err(_) => panic!(),
}
// Insert authoritative servers
{
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "google.com".to_string(),
host: "ns1.google.com".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::A {
domain: "ns1.google.com".to_string(),
addr: "127.0.0.3".parse().unwrap(),
ttl: TransientTtl(3600),
});
let _ = context.cache.store(&nameservers);
}
match resolver.resolve("google.com", QueryType::A, true) {
Ok(packet) => {
assert_eq!(3, packet.header.id);
}
Err(_) => panic!(),
}
}
#[test]
fn test_recursive_resolver_successfully() {
let context = create_test_context(Box::new(|qname, _, _, _| {
let mut packet = DnsPacket::new();
if qname == "google.com" {
packet.answers.push(DnsRecord::A {
domain: "google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
} else {
packet.header.rescode = ResultCode::NXDOMAIN;
packet.authorities.push(DnsRecord::SOA {
domain: "google.com".to_string(),
r_name: "google.com".to_string(),
m_name: "google.com".to_string(),
serial: 0,
refresh: 3600,
retry: 3600,
expire: 3600,
minimum: 3600,
ttl: TransientTtl(3600),
});
}
Ok(packet)
}));
let mut resolver = context.create_resolver(context.clone());
// Insert name servers
let mut nameservers = Vec::new();
nameservers.push(DnsRecord::NS {
domain: "google.com".to_string(),
host: "ns1.google.com".to_string(),
ttl: TransientTtl(3600),
});
nameservers.push(DnsRecord::A {
domain: "ns1.google.com".to_string(),
addr: "127.0.0.1".parse().unwrap(),
ttl: TransientTtl(3600),
});
let _ = context.cache.store(&nameservers);
// Check that we can successfully resolve
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(1, res.answers.len());
match res.answers[0] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
}
};
// And that we won't find anything for a domain that isn't present
{
let res = match resolver.resolve("foobar.google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
assert_eq!(0, res.answers.len());
};
// Perform another successful query, that should hit the cache
{
let res = match resolver.resolve("google.com", QueryType::A, true) {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(1, res.answers.len());
};
// Now check that the cache is used, and that the statistics is correct
{
let list = match context.cache.list() {
Ok(x) => x,
Err(_) => panic!(),
};
assert_eq!(3, list.len());
// Check statistics for google entry
assert_eq!("google.com", list[1].domain);
// Should have a NS record and an A record for a total of 2 record types
assert_eq!(2, list[1].record_types.len());
// Should have been hit two times for NS google.com and once for
// A google.com
assert_eq!(3, list[1].hits);
assert_eq!("ns1.google.com", list[2].domain);
assert_eq!(1, list[2].record_types.len());
assert_eq!(2, list[2].hits);
};
}
}

@ -0,0 +1,608 @@
//! UDP and TCP server implementations for DNS
use std::collections::VecDeque;
use std::io::Write;
use std::net::SocketAddr;
use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket};
use std::sync::atomic::Ordering;
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::Builder;
use derive_more::{Display, Error, From};
use rand::random;
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
use crate::dns::context::ServerContext;
use crate::dns::netutil::{read_packet_length, write_packet_length};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
use crate::dns::resolve::DnsResolver;
#[derive(Debug, Display, From, Error)]
pub enum ServerError {
Io(std::io::Error),
}
type Result<T> = std::result::Result<T, ServerError>;
macro_rules! return_or_report {
( $x:expr, $message:expr ) => {
match $x {
Ok(res) => res,
Err(_) => {
println!($message);
return;
}
}
};
}
macro_rules! ignore_or_report {
( $x:expr, $message:expr ) => {
match $x {
Ok(_) => {}
Err(_) => {
println!($message);
return;
}
};
};
}
/// Common trait for DNS servers
pub trait DnsServer {
/// Initialize the server and start listenening
///
/// This method should _NOT_ block. Rather, servers are expected to spawn a new
/// thread to handle requests and return immediately.
fn run_server(self) -> Result<()>;
}
/// Utility function for resolving domains referenced in for example CNAME or SRV
/// records. This usually spares the client from having to perform additional
/// lookups.
fn resolve_cnames(
lookup_list: &[DnsRecord],
results: &mut Vec<DnsPacket>,
resolver: &mut Box<dyn DnsResolver>,
depth: u16,
) {
if depth > 10 {
return;
}
for ref rec in lookup_list {
match **rec {
DnsRecord::CNAME { ref host, .. } | DnsRecord::SRV { ref host, .. } => {
if let Ok(result2) = resolver.resolve(host, QueryType::A, true) {
let new_unmatched = result2.get_unresolved_cnames();
results.push(result2);
resolve_cnames(&new_unmatched, results, resolver, depth + 1);
}
}
_ => {}
}
}
}
/// Perform the actual work for a query
///
/// Incoming requests are validated to make sure they are well formed and adhere
/// to the server configuration. If so, the request will be passed on to the
/// active resolver and a query will be performed. It will also resolve some
/// possible references within the query, such as CNAME hosts.
///
/// This function will always return a valid packet, even if the request could not
/// be performed, since we still want to send something back to the client.
pub fn execute_query(context: Arc<ServerContext>, request: &DnsPacket) -> DnsPacket {
let mut packet = DnsPacket::new();
packet.header.id = request.header.id;
packet.header.recursion_available = context.allow_recursive;
packet.header.response = true;
if request.header.recursion_desired && !context.allow_recursive {
packet.header.rescode = ResultCode::REFUSED;
} else if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR;
} else {
let mut results = Vec::new();
let question = &request.questions[0];
packet.questions.push(question.clone());
let mut resolver = context.create_resolver(context.clone());
let rescode = match resolver.resolve(
&question.name,
question.qtype,
request.header.recursion_desired,
) {
Ok(result) => {
let rescode = result.header.rescode;
let unmatched = result.get_unresolved_cnames();
results.push(result);
resolve_cnames(&unmatched, &mut results, &mut resolver, 0);
rescode
}
Err(err) => {
println!(
"Failed to resolve {:?} {}: {:?}",
question.qtype, question.name, err
);
ResultCode::SERVFAIL
}
};
packet.header.rescode = rescode;
for result in results {
for rec in result.answers {
packet.answers.push(rec);
}
for rec in result.authorities {
packet.authorities.push(rec);
}
for rec in result.resources {
packet.resources.push(rec);
}
}
}
packet
}
/// The UDP server
///
/// Accepts DNS queries through UDP, and uses the `ServerContext` to determine
/// how to service the request. Packets are read on a single thread, after which
/// a new thread is spawned to service the request asynchronously.
pub struct DnsUdpServer {
context: Arc<ServerContext>,
request_queue: Arc<Mutex<VecDeque<(SocketAddr, DnsPacket)>>>,
request_cond: Arc<Condvar>,
thread_count: usize,
}
impl DnsUdpServer {
pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsUdpServer {
DnsUdpServer {
context: context,
request_queue: Arc::new(Mutex::new(VecDeque::new())),
request_cond: Arc::new(Condvar::new()),
thread_count: thread_count,
}
}
}
impl DnsServer for DnsUdpServer {
/// Launch the server
///
/// This method takes ownership of the server, preventing the method from
/// being called multiple times.
fn run_server(self) -> Result<()> {
// Bind the socket
let socket = UdpSocket::bind(("0.0.0.0", self.context.dns_port))?;
// Spawn threads for handling requests
for thread_id in 0..self.thread_count {
let socket_clone = match socket.try_clone() {
Ok(x) => x,
Err(e) => {
println!("Failed to clone socket when starting UDP server: {:?}", e);
continue;
}
};
let context = self.context.clone();
let request_cond = self.request_cond.clone();
let request_queue = self.request_queue.clone();
let name = "DnsUdpServer-request-".to_string() + &thread_id.to_string();
let _ = Builder::new().name(name).spawn(move || {
loop {
// Acquire lock, and wait on the condition until data is
// available. Then proceed with popping an entry of the queue.
let (src, request) = match request_queue
.lock()
.ok()
.and_then(|x| request_cond.wait(x).ok())
.and_then(|mut x| x.pop_front())
{
Some(x) => x,
None => {
println!("Not expected to happen!");
continue;
}
};
let mut size_limit = 512;
// Check for EDNS
if request.resources.len() == 1 {
if let DnsRecord::OPT { packet_len, .. } = request.resources[0] {
size_limit = packet_len as usize;
}
}
// Create a response buffer, and ask the context for an appropriate
// resolver
let mut res_buffer = VectorPacketBuffer::new();
let mut packet = execute_query(context.clone(), &request);
let _ = packet.write(&mut res_buffer, size_limit);
// Fire off the response
let len = res_buffer.pos();
let data = return_or_report!(
res_buffer.get_range(0, len),
"Failed to get buffer data"
);
ignore_or_report!(
socket_clone.send_to(data, src),
"Failed to send response packet"
);
}
})?;
}
// Start servicing requests
let _ = Builder::new()
.name("DnsUdpServer-incoming".into())
.spawn(move || {
loop {
let _ = self
.context
.statistics
.udp_query_count
.fetch_add(1, Ordering::Release);
// Read a query packet
let mut req_buffer = BytePacketBuffer::new();
let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
Ok(x) => x,
Err(e) => {
println!("Failed to read from UDP socket: {:?}", e);
continue;
}
};
// Parse it
let request = match DnsPacket::from_buffer(&mut req_buffer) {
Ok(x) => x,
Err(e) => {
println!("Failed to parse UDP query packet: {:?}", e);
continue;
}
};
// Acquire lock, add request to queue, and notify waiting threads
// using the condition.
match self.request_queue.lock() {
Ok(mut queue) => {
queue.push_back((src, request));
self.request_cond.notify_one();
}
Err(e) => {
println!("Failed to send UDP request for processing: {}", e);
}
}
}
})?;
Ok(())
}
}
/// TCP DNS server
pub struct DnsTcpServer {
context: Arc<ServerContext>,
senders: Vec<Sender<TcpStream>>,
thread_count: usize,
}
impl DnsTcpServer {
pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsTcpServer {
DnsTcpServer {
context: context,
senders: Vec::new(),
thread_count: thread_count,
}
}
}
impl DnsServer for DnsTcpServer {
fn run_server(mut self) -> Result<()> {
let socket = TcpListener::bind(("0.0.0.0", self.context.dns_port))?;
// Spawn threads for handling requests, and create the channels
for thread_id in 0..self.thread_count {
let (tx, rx) = channel();
self.senders.push(tx);
let context = self.context.clone();
let name = "DnsTcpServer-request-".to_string() + &thread_id.to_string();
let _ = Builder::new().name(name).spawn(move || {
loop {
let mut stream = match rx.recv() {
Ok(x) => x,
Err(_) => continue,
};
let _ = context
.statistics
.tcp_query_count
.fetch_add(1, Ordering::Release);
// When DNS packets are sent over TCP, they're prefixed with a two byte
// length. We don't really need to know the length in advance, so we
// just move past it and continue reading as usual
ignore_or_report!(
read_packet_length(&mut stream),
"Failed to read query packet length"
);
let request = {
let mut stream_buffer = StreamPacketBuffer::new(&mut stream);
return_or_report!(
DnsPacket::from_buffer(&mut stream_buffer),
"Failed to read query packet"
)
};
let mut res_buffer = VectorPacketBuffer::new();
let mut packet = execute_query(context.clone(), &request);
ignore_or_report!(
packet.write(&mut res_buffer, 0xFFFF),
"Failed to write packet to buffer"
);
// As is the case for incoming queries, we need to send a 2 byte length
// value before handing of the actual packet.
let len = res_buffer.pos();
ignore_or_report!(
write_packet_length(&mut stream, len),
"Failed to write packet size"
);
// Now we can go ahead and write the actual packet
let data = return_or_report!(
res_buffer.get_range(0, len),
"Failed to get packet data"
);
ignore_or_report!(stream.write(data), "Failed to write response packet");
ignore_or_report!(stream.shutdown(Shutdown::Both), "Failed to shutdown socket");
}
})?;
}
let _ = Builder::new()
.name("DnsTcpServer-incoming".into())
.spawn(move || {
for wrap_stream in socket.incoming() {
let stream = match wrap_stream {
Ok(stream) => stream,
Err(err) => {
println!("Failed to accept TCP connection: {:?}", err);
continue;
}
};
// Hand it off to a worker thread
let thread_no = random::<usize>() % self.thread_count;
match self.senders[thread_no].send(stream) {
Ok(_) => {}
Err(e) => {
println!(
"Failed to send TCP request for processing on thread {}: {}",
thread_no, e
);
}
}
}
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use std::sync::Arc;
use crate::dns::protocol::{
DnsPacket, DnsQuestion, DnsRecord, QueryType, ResultCode, TransientTtl,
};
use super::*;
use crate::dns::context::tests::create_test_context;
use crate::dns::context::ResolveStrategy;
fn build_query(qname: &str, qtype: QueryType) -> DnsPacket {
let mut query_packet = DnsPacket::new();
query_packet.header.recursion_desired = true;
query_packet
.questions
.push(DnsQuestion::new(qname.into(), qtype));
query_packet
}
#[test]
fn test_execute_query() {
// Construct a context to execute some queries successfully
let mut context = create_test_context(Box::new(|qname, qtype, _, _| {
let mut packet = DnsPacket::new();
if qname == "google.com" {
packet.answers.push(DnsRecord::A {
domain: "google.com".to_string(),
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
ttl: TransientTtl(3600),
});
} else if qname == "www.facebook.com" && qtype == QueryType::CNAME {
packet.answers.push(DnsRecord::CNAME {
domain: "www.facebook.com".to_string(),
host: "cdn.facebook.com".to_string(),
ttl: TransientTtl(3600),
});
packet.answers.push(DnsRecord::A {
domain: "cdn.facebook.com".to_string(),
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
ttl: TransientTtl(3600),
});
} else if qname == "www.microsoft.com" && qtype == QueryType::CNAME {
packet.answers.push(DnsRecord::CNAME {
domain: "www.microsoft.com".to_string(),
host: "cdn.microsoft.com".to_string(),
ttl: TransientTtl(3600),
});
} else if qname == "cdn.microsoft.com" && qtype == QueryType::A {
packet.answers.push(DnsRecord::A {
domain: "cdn.microsoft.com".to_string(),
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
ttl: TransientTtl(3600),
});
} else {
packet.header.rescode = ResultCode::NXDOMAIN;
}
Ok(packet)
}));
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
};
}
None => panic!(),
}
// A successful resolve
{
let res = execute_query(context.clone(), &build_query("google.com", QueryType::A));
assert_eq!(1, res.answers.len());
match res.answers[0] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("google.com", domain);
}
_ => panic!(),
}
};
// A successful resolve, that also resolves a CNAME without recursive lookup
{
let res = execute_query(
context.clone(),
&build_query("www.facebook.com", QueryType::CNAME),
);
assert_eq!(2, res.answers.len());
match res.answers[0] {
DnsRecord::CNAME { ref domain, .. } => {
assert_eq!("www.facebook.com", domain);
}
_ => panic!(),
}
match res.answers[1] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("cdn.facebook.com", domain);
}
_ => panic!(),
}
};
// A successful resolve, that also resolves a CNAME through recursive lookup
{
let res = execute_query(
context.clone(),
&build_query("www.microsoft.com", QueryType::CNAME),
);
assert_eq!(2, res.answers.len());
match res.answers[0] {
DnsRecord::CNAME { ref domain, .. } => {
assert_eq!("www.microsoft.com", domain);
}
_ => panic!(),
}
match res.answers[1] {
DnsRecord::A { ref domain, .. } => {
assert_eq!("cdn.microsoft.com", domain);
}
_ => panic!(),
}
};
// An unsuccessful resolve, but without any error
{
let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A));
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
assert_eq!(0, res.answers.len());
};
// Disable recursive resolves to generate a failure
match Arc::get_mut(&mut context) {
Some(mut ctx) => {
ctx.allow_recursive = false;
}
None => panic!(),
}
// This should generate an error code, since recursive resolves are
// no longer allowed
{
let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A));
assert_eq!(ResultCode::REFUSED, res.header.rescode);
assert_eq!(0, res.answers.len());
};
// Send a query without a question, which should fail with an error code
{
let query_packet = DnsPacket::new();
let res = execute_query(context.clone(), &query_packet);
assert_eq!(ResultCode::FORMERR, res.header.rescode);
assert_eq!(0, res.answers.len());
};
// Now construct a context where the dns client will return a failure
let mut context2 = create_test_context(Box::new(|_, _, _, _| {
Err(crate::dns::client::ClientError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"Fail",
)))
}));
match Arc::get_mut(&mut context2) {
Some(mut ctx) => {
ctx.resolve_strategy = ResolveStrategy::Forward {
host: "127.0.0.1".to_string(),
port: 53,
};
}
None => panic!(),
}
// We expect this to set the server failure rescode
{
let res = execute_query(context2.clone(), &build_query("yahoo.com", QueryType::A));
assert_eq!(ResultCode::SERVFAIL, res.header.rescode);
assert_eq!(0, res.answers.len());
};
}
}

@ -17,4 +17,5 @@ pub mod miner;
pub mod context;
pub mod event;
pub mod p2p;
pub mod dns;

Loading…
Cancel
Save