mirror of https://github.com/Revertron/Alfis
First DNS compile. Took DNS code from https://github.com/EmilHernvall/hermes.
parent
b4ae51088d
commit
4b5e5112da
@ -0,0 +1,257 @@
|
||||
//! contains the data store for local zones
|
||||
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
use derive_more::{Display, From, Error};
|
||||
|
||||
use crate::dns::buffer::{PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum AuthorityError {
|
||||
Buffer(crate::dns::buffer::BufferError),
|
||||
Protocol(crate::dns::protocol::ProtocolError),
|
||||
Io(std::io::Error),
|
||||
PoisonedLock,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, AuthorityError>;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Zone {
|
||||
pub domain: String,
|
||||
pub m_name: String,
|
||||
pub r_name: String,
|
||||
pub serial: u32,
|
||||
pub refresh: u32,
|
||||
pub retry: u32,
|
||||
pub expire: u32,
|
||||
pub minimum: u32,
|
||||
pub records: BTreeSet<DnsRecord>,
|
||||
}
|
||||
|
||||
impl Zone {
|
||||
pub fn new(domain: String, m_name: String, r_name: String) -> Zone {
|
||||
Zone {
|
||||
domain: domain,
|
||||
m_name: m_name,
|
||||
r_name: r_name,
|
||||
serial: 0,
|
||||
refresh: 0,
|
||||
retry: 0,
|
||||
expire: 0,
|
||||
minimum: 0,
|
||||
records: BTreeSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_record(&mut self, rec: &DnsRecord) -> bool {
|
||||
self.records.insert(rec.clone())
|
||||
}
|
||||
|
||||
pub fn delete_record(&mut self, rec: &DnsRecord) -> bool {
|
||||
self.records.remove(rec)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Zones {
|
||||
zones: BTreeMap<String, Zone>,
|
||||
}
|
||||
|
||||
impl<'a> Zones {
|
||||
pub fn new() -> Zones {
|
||||
Zones {
|
||||
zones: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(&mut self) -> Result<()> {
|
||||
let zones_dir = Path::new("zones").read_dir()?;
|
||||
|
||||
for wrapped_filename in zones_dir {
|
||||
let filename = match wrapped_filename {
|
||||
Ok(x) => x,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let mut zone_file = match File::open(filename.path()) {
|
||||
Ok(x) => x,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let mut buffer = StreamPacketBuffer::new(&mut zone_file);
|
||||
|
||||
let mut zone = Zone::new(String::new(), String::new(), String::new());
|
||||
buffer.read_qname(&mut zone.domain)?;
|
||||
buffer.read_qname(&mut zone.m_name)?;
|
||||
buffer.read_qname(&mut zone.r_name)?;
|
||||
zone.serial = buffer.read_u32()?;
|
||||
zone.refresh = buffer.read_u32()?;
|
||||
zone.retry = buffer.read_u32()?;
|
||||
zone.expire = buffer.read_u32()?;
|
||||
zone.minimum = buffer.read_u32()?;
|
||||
|
||||
let record_count = buffer.read_u32()?;
|
||||
|
||||
for _ in 0..record_count {
|
||||
let rr = DnsRecord::read(&mut buffer)?;
|
||||
zone.add_record(&rr);
|
||||
}
|
||||
|
||||
println!("Loaded zone {} with {} records", zone.domain, record_count);
|
||||
|
||||
self.zones.insert(zone.domain.clone(), zone);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save(&mut self) -> Result<()> {
|
||||
let zones_dir = Path::new("zones");
|
||||
for zone in self.zones.values() {
|
||||
let filename = zones_dir.join(Path::new(&zone.domain));
|
||||
let mut zone_file = match File::create(&filename) {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
println!("Failed to save file {:?}", filename);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut buffer = VectorPacketBuffer::new();
|
||||
let _ = buffer.write_qname(&zone.domain);
|
||||
let _ = buffer.write_qname(&zone.m_name);
|
||||
let _ = buffer.write_qname(&zone.r_name);
|
||||
let _ = buffer.write_u32(zone.serial);
|
||||
let _ = buffer.write_u32(zone.refresh);
|
||||
let _ = buffer.write_u32(zone.retry);
|
||||
let _ = buffer.write_u32(zone.expire);
|
||||
let _ = buffer.write_u32(zone.minimum);
|
||||
let _ = buffer.write_u32(zone.records.len() as u32);
|
||||
|
||||
for rec in &zone.records {
|
||||
let _ = rec.write(&mut buffer);
|
||||
}
|
||||
|
||||
let _ = zone_file.write(&buffer.buffer[0..buffer.pos]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn zones(&self) -> Vec<&Zone> {
|
||||
self.zones.values().collect()
|
||||
}
|
||||
|
||||
pub fn add_zone(&mut self, zone: Zone) {
|
||||
self.zones.insert(zone.domain.clone(), zone);
|
||||
}
|
||||
|
||||
pub fn get_zone(&'a self, domain: &str) -> Option<&'a Zone> {
|
||||
self.zones.get(domain)
|
||||
}
|
||||
|
||||
pub fn get_zone_mut(&'a mut self, domain: &str) -> Option<&'a mut Zone> {
|
||||
self.zones.get_mut(domain)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Authority {
|
||||
zones: RwLock<Zones>,
|
||||
}
|
||||
|
||||
impl Authority {
|
||||
pub fn new() -> Authority {
|
||||
Authority {
|
||||
zones: RwLock::new(Zones::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(&self) -> Result<()> {
|
||||
let mut zones = self
|
||||
.zones
|
||||
.write()
|
||||
.map_err(|_| AuthorityError::PoisonedLock)?;
|
||||
zones.load()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn query(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
let zones = match self.zones.read().ok() {
|
||||
Some(x) => x,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let mut best_match = None;
|
||||
for zone in zones.zones() {
|
||||
if !qname.ends_with(&zone.domain) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((len, _)) = best_match {
|
||||
if len < zone.domain.len() {
|
||||
best_match = Some((zone.domain.len(), zone));
|
||||
}
|
||||
} else {
|
||||
best_match = Some((zone.domain.len(), zone));
|
||||
}
|
||||
}
|
||||
|
||||
let zone = match best_match {
|
||||
Some((_, zone)) => zone,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.authoritative_answer = true;
|
||||
|
||||
for rec in &zone.records {
|
||||
let domain = match rec.get_domain() {
|
||||
Some(x) => x,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if &domain != qname {
|
||||
continue;
|
||||
}
|
||||
|
||||
let rtype = rec.get_querytype();
|
||||
if qtype == rtype || (qtype == QueryType::A && rtype == QueryType::CNAME) {
|
||||
packet.answers.push(rec.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if packet.answers.is_empty() {
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
|
||||
packet.authorities.push(DnsRecord::SOA {
|
||||
domain: zone.domain.clone(),
|
||||
m_name: zone.m_name.clone(),
|
||||
r_name: zone.r_name.clone(),
|
||||
serial: zone.serial,
|
||||
refresh: zone.refresh,
|
||||
retry: zone.retry,
|
||||
expire: zone.expire,
|
||||
minimum: zone.minimum,
|
||||
ttl: TransientTtl(zone.minimum),
|
||||
});
|
||||
}
|
||||
|
||||
Some(packet)
|
||||
}
|
||||
|
||||
pub fn read(&self) -> LockResult<RwLockReadGuard<'_, Zones>> {
|
||||
self.zones.read()
|
||||
}
|
||||
|
||||
pub fn write(&self) -> LockResult<RwLockWriteGuard<'_, Zones>> {
|
||||
self.zones.write()
|
||||
}
|
||||
}
|
@ -0,0 +1,487 @@
|
||||
//! buffers for use when writing and reading dns packets
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::io::Read;
|
||||
|
||||
use derive_more::{Display, Error, From};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum BufferError {
|
||||
Io(std::io::Error),
|
||||
EndOfBuffer,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, BufferError>;
|
||||
|
||||
pub trait PacketBuffer {
|
||||
fn read(&mut self) -> Result<u8>;
|
||||
fn get(&mut self, pos: usize) -> Result<u8>;
|
||||
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]>;
|
||||
fn write(&mut self, val: u8) -> Result<()>;
|
||||
fn set(&mut self, pos: usize, val: u8) -> Result<()>;
|
||||
fn pos(&self) -> usize;
|
||||
fn seek(&mut self, pos: usize) -> Result<()>;
|
||||
fn step(&mut self, steps: usize) -> Result<()>;
|
||||
fn find_label(&self, label: &str) -> Option<usize>;
|
||||
fn save_label(&mut self, label: &str, pos: usize);
|
||||
|
||||
fn write_u8(&mut self, val: u8) -> Result<()> {
|
||||
self.write(val)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
|
||||
self.set(pos, (val >> 8) as u8)?;
|
||||
self.set(pos + 1, (val & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_u16(&mut self, val: u16) -> Result<()> {
|
||||
self.write((val >> 8) as u8)?;
|
||||
self.write((val & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_u32(&mut self, val: u32) -> Result<()> {
|
||||
self.write(((val >> 24) & 0xFF) as u8)?;
|
||||
self.write(((val >> 16) & 0xFF) as u8)?;
|
||||
self.write(((val >> 8) & 0xFF) as u8)?;
|
||||
self.write(((val >> 0) & 0xFF) as u8)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_qname(&mut self, qname: &str) -> Result<()> {
|
||||
let split_str = qname.split('.').collect::<Vec<&str>>();
|
||||
|
||||
let mut jump_performed = false;
|
||||
for (i, label) in split_str.iter().enumerate() {
|
||||
let search_lbl = split_str[i..split_str.len()].join(".");
|
||||
if let Some(prev_pos) = self.find_label(&search_lbl) {
|
||||
let jump_inst = (prev_pos as u16) | 0xC000;
|
||||
self.write_u16(jump_inst)?;
|
||||
jump_performed = true;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
let pos = self.pos();
|
||||
self.save_label(&search_lbl, pos);
|
||||
|
||||
let len = label.len();
|
||||
self.write_u8(len as u8)?;
|
||||
for b in label.as_bytes() {
|
||||
self.write_u8(*b)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !jump_performed {
|
||||
self.write_u8(0)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_u16(&mut self) -> Result<u16> {
|
||||
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn read_u32(&mut self) -> Result<u32> {
|
||||
let res = ((self.read()? as u32) << 24)
|
||||
| ((self.read()? as u32) << 16)
|
||||
| ((self.read()? as u32) << 8)
|
||||
| ((self.read()? as u32) << 0);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
|
||||
let mut pos = self.pos();
|
||||
let mut jumped = false;
|
||||
|
||||
let mut delim = "";
|
||||
loop {
|
||||
let len = self.get(pos)?;
|
||||
|
||||
// A two byte sequence, where the two highest bits of the first byte is
|
||||
// set, represents a offset relative to the start of the buffer. We
|
||||
// handle this by jumping to the offset, setting a flag to indicate
|
||||
// that we shouldn't update the shared buffer position once done.
|
||||
if (len & 0xC0) > 0 {
|
||||
// When a jump is performed, we only modify the shared buffer
|
||||
// position once, and avoid making the change later on.
|
||||
if !jumped {
|
||||
self.seek(pos + 2)?;
|
||||
}
|
||||
|
||||
let b2 = self.get(pos + 1)? as u16;
|
||||
let offset = (((len as u16) ^ 0xC0) << 8) | b2;
|
||||
pos = offset as usize;
|
||||
jumped = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
pos += 1;
|
||||
|
||||
// Names are terminated by an empty label of length 0
|
||||
if len == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
outstr.push_str(delim);
|
||||
|
||||
let str_buffer = self.get_range(pos, len as usize)?;
|
||||
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
|
||||
|
||||
delim = ".";
|
||||
|
||||
pos += len as usize;
|
||||
}
|
||||
|
||||
if !jumped {
|
||||
self.seek(pos)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VectorPacketBuffer {
|
||||
pub buffer: Vec<u8>,
|
||||
pub pos: usize,
|
||||
pub label_lookup: BTreeMap<String, usize>,
|
||||
}
|
||||
|
||||
impl VectorPacketBuffer {
|
||||
pub fn new() -> VectorPacketBuffer {
|
||||
VectorPacketBuffer {
|
||||
buffer: Vec::new(),
|
||||
pos: 0,
|
||||
label_lookup: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketBuffer for VectorPacketBuffer {
|
||||
fn find_label(&self, label: &str) -> Option<usize> {
|
||||
self.label_lookup.get(label).cloned()
|
||||
}
|
||||
|
||||
fn save_label(&mut self, label: &str, pos: usize) {
|
||||
self.label_lookup.insert(label.to_string(), pos);
|
||||
}
|
||||
|
||||
fn read(&mut self) -> Result<u8> {
|
||||
let res = self.buffer[self.pos];
|
||||
self.pos += 1;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
Ok(self.buffer[pos])
|
||||
}
|
||||
|
||||
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
Ok(&self.buffer[start..start + len as usize])
|
||||
}
|
||||
|
||||
fn write(&mut self, val: u8) -> Result<()> {
|
||||
self.buffer.push(val);
|
||||
self.pos += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set(&mut self, pos: usize, val: u8) -> Result<()> {
|
||||
self.buffer[pos] = val;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pos(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
fn seek(&mut self, pos: usize) -> Result<()> {
|
||||
self.pos = pos;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn step(&mut self, steps: usize) -> Result<()> {
|
||||
self.pos += steps;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StreamPacketBuffer<'a, T>
|
||||
where
|
||||
T: Read,
|
||||
{
|
||||
pub stream: &'a mut T,
|
||||
pub buffer: Vec<u8>,
|
||||
pub pos: usize,
|
||||
}
|
||||
|
||||
impl<'a, T> StreamPacketBuffer<'a, T>
|
||||
where
|
||||
T: Read + 'a,
|
||||
{
|
||||
pub fn new(stream: &'a mut T) -> StreamPacketBuffer<'_, T> {
|
||||
StreamPacketBuffer {
|
||||
stream: stream,
|
||||
buffer: Vec::new(),
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> PacketBuffer for StreamPacketBuffer<'a, T>
|
||||
where
|
||||
T: Read + 'a,
|
||||
{
|
||||
fn find_label(&self, _: &str) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn save_label(&mut self, _: &str, _: usize) {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn read(&mut self) -> Result<u8> {
|
||||
while self.pos >= self.buffer.len() {
|
||||
let mut local_buffer = [0; 1];
|
||||
self.stream.read(&mut local_buffer)?;
|
||||
self.buffer.push(local_buffer[0]);
|
||||
}
|
||||
|
||||
let res = self.buffer[self.pos];
|
||||
self.pos += 1;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
while pos >= self.buffer.len() {
|
||||
let mut local_buffer = [0; 1];
|
||||
self.stream.read(&mut local_buffer)?;
|
||||
self.buffer.push(local_buffer[0]);
|
||||
}
|
||||
|
||||
Ok(self.buffer[pos])
|
||||
}
|
||||
|
||||
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
while start + len > self.buffer.len() {
|
||||
let mut local_buffer = [0; 1];
|
||||
self.stream.read(&mut local_buffer)?;
|
||||
self.buffer.push(local_buffer[0]);
|
||||
}
|
||||
|
||||
Ok(&self.buffer[start..start + len as usize])
|
||||
}
|
||||
|
||||
fn write(&mut self, _: u8) -> Result<()> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn set(&mut self, _: usize, _: u8) -> Result<()> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn pos(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
fn seek(&mut self, pos: usize) -> Result<()> {
|
||||
self.pos = pos;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn step(&mut self, steps: usize) -> Result<()> {
|
||||
self.pos += steps;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BytePacketBuffer {
|
||||
pub buf: [u8; 512],
|
||||
pub pos: usize,
|
||||
}
|
||||
|
||||
impl BytePacketBuffer {
|
||||
pub fn new() -> BytePacketBuffer {
|
||||
BytePacketBuffer {
|
||||
buf: [0; 512],
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BytePacketBuffer {
|
||||
fn default() -> Self {
|
||||
BytePacketBuffer::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketBuffer for BytePacketBuffer {
|
||||
fn find_label(&self, _: &str) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
|
||||
fn save_label(&mut self, _: &str, _: usize) {}
|
||||
|
||||
fn read(&mut self) -> Result<u8> {
|
||||
if self.pos >= 512 {
|
||||
return Err(BufferError::EndOfBuffer);
|
||||
}
|
||||
let res = self.buf[self.pos];
|
||||
self.pos += 1;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn get(&mut self, pos: usize) -> Result<u8> {
|
||||
if pos >= 512 {
|
||||
return Err(BufferError::EndOfBuffer);
|
||||
}
|
||||
Ok(self.buf[pos])
|
||||
}
|
||||
|
||||
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
|
||||
if start + len >= 512 {
|
||||
return Err(BufferError::EndOfBuffer);
|
||||
}
|
||||
Ok(&self.buf[start..start + len as usize])
|
||||
}
|
||||
|
||||
fn write(&mut self, val: u8) -> Result<()> {
|
||||
if self.pos >= 512 {
|
||||
return Err(BufferError::EndOfBuffer);
|
||||
}
|
||||
self.buf[self.pos] = val;
|
||||
self.pos += 1;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set(&mut self, pos: usize, val: u8) -> Result<()> {
|
||||
self.buf[pos] = val;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pos(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
fn seek(&mut self, pos: usize) -> Result<()> {
|
||||
self.pos = pos;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn step(&mut self, steps: usize) -> Result<()> {
|
||||
self.pos += steps;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_qname() {
|
||||
let mut buffer = VectorPacketBuffer::new();
|
||||
|
||||
let instr1 = "a.google.com".to_string();
|
||||
let instr2 = "b.google.com".to_string();
|
||||
|
||||
// First write the standard string
|
||||
match buffer.write_qname(&instr1) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
// Then we set up a slight variation with relies on a jump back to the data of
|
||||
// the first name
|
||||
let crafted_data = [0x01, b'b' as u8, 0xC0, 0x02];
|
||||
for b in &crafted_data {
|
||||
match buffer.write_u8(*b) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the buffer position for reading
|
||||
buffer.pos = 0;
|
||||
|
||||
// Read the standard name
|
||||
let mut outstr1 = String::new();
|
||||
match buffer.read_qname(&mut outstr1) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!(instr1, outstr1);
|
||||
|
||||
// Read the name with a jump
|
||||
let mut outstr2 = String::new();
|
||||
match buffer.read_qname(&mut outstr2) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!(instr2, outstr2);
|
||||
|
||||
// Make sure we're now at the end of the buffer
|
||||
assert_eq!(buffer.pos, buffer.buffer.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_qname() {
|
||||
let mut buffer = VectorPacketBuffer::new();
|
||||
|
||||
match buffer.write_qname(&"ns1.google.com".to_string()) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
match buffer.write_qname(&"ns2.google.com".to_string()) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!(22, buffer.pos());
|
||||
|
||||
match buffer.seek(0) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
let mut str1 = String::new();
|
||||
match buffer.read_qname(&mut str1) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!("ns1.google.com", str1);
|
||||
|
||||
let mut str2 = String::new();
|
||||
match buffer.read_qname(&mut str2) {
|
||||
Ok(_) => {}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!("ns2.google.com", str2);
|
||||
}
|
||||
}
|
@ -0,0 +1,462 @@
|
||||
//! a threadsafe cache for DNS information
|
||||
|
||||
extern crate serde;
|
||||
use std::clone::Clone;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use chrono::*;
|
||||
use derive_more::{Display, Error, From};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum CacheError {
|
||||
Io(std::io::Error),
|
||||
PoisonedLock,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, CacheError>;
|
||||
|
||||
pub enum CacheState {
|
||||
PositiveCache,
|
||||
NegativeCache,
|
||||
NotCached,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, Debug, Serialize, Deserialize)]
|
||||
pub struct RecordEntry {
|
||||
pub record: DnsRecord,
|
||||
pub timestamp: DateTime<Local>,
|
||||
}
|
||||
|
||||
impl PartialEq<RecordEntry> for RecordEntry {
|
||||
fn eq(&self, other: &RecordEntry) -> bool {
|
||||
self.record == other.record
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for RecordEntry {
|
||||
fn hash<H>(&self, state: &mut H)
|
||||
where
|
||||
H: Hasher,
|
||||
{
|
||||
self.record.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum RecordSet {
|
||||
NoRecords {
|
||||
qtype: QueryType,
|
||||
ttl: u32,
|
||||
timestamp: DateTime<Local>,
|
||||
},
|
||||
Records {
|
||||
qtype: QueryType,
|
||||
records: HashSet<RecordEntry>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DomainEntry {
|
||||
pub domain: String,
|
||||
pub record_types: HashMap<QueryType, RecordSet>,
|
||||
pub hits: u32,
|
||||
pub updates: u32,
|
||||
}
|
||||
|
||||
impl DomainEntry {
|
||||
pub fn new(domain: String) -> DomainEntry {
|
||||
DomainEntry {
|
||||
domain: domain,
|
||||
record_types: HashMap::new(),
|
||||
hits: 0,
|
||||
updates: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) {
|
||||
self.updates += 1;
|
||||
|
||||
let new_set = RecordSet::NoRecords {
|
||||
qtype: qtype,
|
||||
ttl: ttl,
|
||||
timestamp: Local::now(),
|
||||
};
|
||||
|
||||
self.record_types.insert(qtype, new_set);
|
||||
}
|
||||
|
||||
pub fn store_record(&mut self, rec: &DnsRecord) {
|
||||
self.updates += 1;
|
||||
|
||||
let entry = RecordEntry {
|
||||
record: rec.clone(),
|
||||
timestamp: Local::now(),
|
||||
};
|
||||
|
||||
if let Some(&mut RecordSet::Records {
|
||||
ref mut records, ..
|
||||
}) = self.record_types.get_mut(&rec.get_querytype())
|
||||
{
|
||||
if records.contains(&entry) {
|
||||
records.remove(&entry);
|
||||
}
|
||||
|
||||
records.insert(entry);
|
||||
return;
|
||||
}
|
||||
|
||||
let mut records = HashSet::new();
|
||||
records.insert(entry);
|
||||
|
||||
let new_set = RecordSet::Records {
|
||||
qtype: rec.get_querytype(),
|
||||
records: records,
|
||||
};
|
||||
|
||||
self.record_types.insert(rec.get_querytype(), new_set);
|
||||
}
|
||||
|
||||
pub fn get_cache_state(&self, qtype: QueryType) -> CacheState {
|
||||
match self.record_types.get(&qtype) {
|
||||
Some(&RecordSet::Records { ref records, .. }) => {
|
||||
let now = Local::now();
|
||||
|
||||
let mut valid_count = 0;
|
||||
for entry in records {
|
||||
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
|
||||
let expires = entry.timestamp + ttl_offset;
|
||||
if expires < now {
|
||||
continue;
|
||||
}
|
||||
|
||||
if entry.record.get_querytype() == qtype {
|
||||
valid_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if valid_count > 0 {
|
||||
CacheState::PositiveCache
|
||||
} else {
|
||||
CacheState::NotCached
|
||||
}
|
||||
}
|
||||
Some(&RecordSet::NoRecords { ttl, timestamp, .. }) => {
|
||||
let now = Local::now();
|
||||
let ttl_offset = Duration::seconds(ttl as i64);
|
||||
let expires = timestamp + ttl_offset;
|
||||
|
||||
if expires < now {
|
||||
CacheState::NotCached
|
||||
} else {
|
||||
CacheState::NegativeCache
|
||||
}
|
||||
}
|
||||
None => CacheState::NotCached,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_queryresult(&self, qtype: QueryType, result_vec: &mut Vec<DnsRecord>) {
|
||||
let now = Local::now();
|
||||
|
||||
let current_set = match self.record_types.get(&qtype) {
|
||||
Some(x) => x,
|
||||
None => return,
|
||||
};
|
||||
|
||||
if let RecordSet::Records { ref records, .. } = *current_set {
|
||||
for entry in records {
|
||||
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
|
||||
let expires = entry.timestamp + ttl_offset;
|
||||
if expires < now {
|
||||
continue;
|
||||
}
|
||||
|
||||
if entry.record.get_querytype() == qtype {
|
||||
result_vec.push(entry.record.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Cache {
|
||||
domain_entries: BTreeMap<String, Arc<DomainEntry>>,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new() -> Cache {
|
||||
Cache {
|
||||
domain_entries: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
|
||||
match self.domain_entries.get(qname) {
|
||||
Some(x) => x.get_cache_state(qtype),
|
||||
None => CacheState::NotCached,
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_queryresult(
|
||||
&mut self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
result_vec: &mut Vec<DnsRecord>,
|
||||
increment_stats: bool,
|
||||
) {
|
||||
if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
|
||||
if increment_stats {
|
||||
domain_entry.hits += 1
|
||||
}
|
||||
|
||||
domain_entry.fill_queryresult(qtype, result_vec);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lookup(&mut self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
match self.get_cache_state(qname, qtype) {
|
||||
CacheState::PositiveCache => {
|
||||
let mut qr = DnsPacket::new();
|
||||
self.fill_queryresult(qname, qtype, &mut qr.answers, true);
|
||||
self.fill_queryresult(qname, QueryType::NS, &mut qr.authorities, false);
|
||||
|
||||
Some(qr)
|
||||
}
|
||||
CacheState::NegativeCache => {
|
||||
let mut qr = DnsPacket::new();
|
||||
qr.header.rescode = ResultCode::NXDOMAIN;
|
||||
|
||||
Some(qr)
|
||||
}
|
||||
CacheState::NotCached => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store(&mut self, records: &[DnsRecord]) {
|
||||
for rec in records {
|
||||
let domain = match rec.get_domain() {
|
||||
Some(x) => x,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain).and_then(Arc::get_mut) {
|
||||
rs.store_record(rec);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut rs = DomainEntry::new(domain.clone());
|
||||
rs.store_record(rec);
|
||||
self.domain_entries.insert(domain.clone(), Arc::new(rs));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
|
||||
if let Some(ref mut rs) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
return;
|
||||
}
|
||||
|
||||
let mut rs = DomainEntry::new(qname.to_string());
|
||||
rs.store_nxdomain(qtype, ttl);
|
||||
self.domain_entries.insert(qname.to_string(), Arc::new(rs));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SynchronizedCache {
|
||||
pub cache: RwLock<Cache>,
|
||||
}
|
||||
|
||||
impl SynchronizedCache {
|
||||
pub fn new() -> SynchronizedCache {
|
||||
SynchronizedCache {
|
||||
cache: RwLock::new(Cache::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
|
||||
let cache = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
let mut list = Vec::new();
|
||||
|
||||
for rs in cache.domain_entries.values() {
|
||||
list.push(rs.clone());
|
||||
}
|
||||
|
||||
Ok(list)
|
||||
}
|
||||
|
||||
pub fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
|
||||
let mut cache = match self.cache.write() {
|
||||
Ok(x) => x,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
cache.lookup(qname, qtype)
|
||||
}
|
||||
|
||||
pub fn store(&self, records: &[DnsRecord]) -> Result<()> {
|
||||
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
cache.store(records);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn store_nxdomain(&self, qname: &str, qtype: QueryType, ttl: u32) -> Result<()> {
|
||||
let mut cache = self.cache.write().map_err(|_| CacheError::PoisonedLock)?;
|
||||
|
||||
cache.store_nxdomain(qname, qtype, ttl);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::dns::protocol::{DnsRecord, QueryType, ResultCode, TransientTtl};
|
||||
|
||||
#[test]
|
||||
fn test_cache() {
|
||||
let mut cache = Cache::new();
|
||||
|
||||
// Verify that no data is returned when nothing is present
|
||||
if cache.lookup("www.google.com", QueryType::A).is_some() {
|
||||
panic!()
|
||||
}
|
||||
|
||||
// Register a negative cache entry
|
||||
cache.store_nxdomain("www.google.com", QueryType::A, 3600);
|
||||
|
||||
// Verify that we get a response, with the NXDOMAIN flag set
|
||||
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
|
||||
assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);
|
||||
}
|
||||
|
||||
// Register a negative cache entry with no TTL
|
||||
cache.store_nxdomain("www.yahoo.com", QueryType::A, 0);
|
||||
|
||||
// And check that no such result is actually returned, since it's expired
|
||||
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!()
|
||||
}
|
||||
|
||||
// Now add some actual records
|
||||
let mut records = Vec::new();
|
||||
records.push(DnsRecord::A {
|
||||
domain: "www.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
records.push(DnsRecord::A {
|
||||
domain: "www.yahoo.com".to_string(),
|
||||
addr: "127.0.0.2".parse().unwrap(),
|
||||
ttl: TransientTtl(0),
|
||||
});
|
||||
records.push(DnsRecord::CNAME {
|
||||
domain: "www.microsoft.com".to_string(),
|
||||
host: "www.somecdn.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
cache.store(&records);
|
||||
|
||||
// Test for successful lookup
|
||||
if let Some(packet) = cache.lookup("www.google.com", QueryType::A) {
|
||||
assert_eq!(records[0], packet.answers[0]);
|
||||
} else {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Test for failed lookup, since no CNAME's are known for this domain
|
||||
if cache.lookup("www.google.com", QueryType::CNAME).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Check for successful CNAME lookup
|
||||
if let Some(packet) = cache.lookup("www.microsoft.com", QueryType::CNAME) {
|
||||
assert_eq!(records[2], packet.answers[0]);
|
||||
} else {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// This lookup should fail, since it has expired due to the 0 second TTL
|
||||
if cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
let mut records2 = Vec::new();
|
||||
records2.push(DnsRecord::A {
|
||||
domain: "www.yahoo.com".to_string(),
|
||||
addr: "127.0.0.2".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
cache.store(&records2);
|
||||
|
||||
// And now it should succeed, since the record has been store
|
||||
if !cache.lookup("www.yahoo.com", QueryType::A).is_some() {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Check stat counter behavior
|
||||
assert_eq!(3, cache.domain_entries.len());
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.google.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
assert_eq!(
|
||||
2,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.google.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.yahoo.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
assert_eq!(
|
||||
3,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.yahoo.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.microsoft.com".to_string())
|
||||
.unwrap()
|
||||
.updates
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
cache
|
||||
.domain_entries
|
||||
.get(&"www.microsoft.com".to_string())
|
||||
.unwrap()
|
||||
.hits
|
||||
);
|
||||
}
|
||||
}
|
@ -0,0 +1,400 @@
|
||||
//! client for sending DNS queries to other servers
|
||||
|
||||
use std::io::Write;
|
||||
use std::marker::{Send, Sync};
|
||||
use std::net::{TcpStream, UdpSocket};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::{channel, Sender};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread::{sleep, Builder};
|
||||
use std::time::Duration as SleepDuration;
|
||||
|
||||
use chrono::*;
|
||||
use derive_more::{Display, Error, From};
|
||||
|
||||
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer};
|
||||
use crate::dns::netutil::{read_packet_length, write_packet_length};
|
||||
use crate::dns::protocol::{DnsPacket, DnsQuestion, QueryType};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum ClientError {
|
||||
Protocol(crate::dns::protocol::ProtocolError),
|
||||
Io(std::io::Error),
|
||||
PoisonedLock,
|
||||
LookupFailed,
|
||||
TimeOut,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, ClientError>;
|
||||
|
||||
pub trait DnsClient {
|
||||
fn get_sent_count(&self) -> usize;
|
||||
fn get_failed_count(&self) -> usize;
|
||||
|
||||
fn run(&self) -> Result<()>;
|
||||
fn send_query(
|
||||
&self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: (&str, u16),
|
||||
recursive: bool,
|
||||
) -> Result<DnsPacket>;
|
||||
}
|
||||
|
||||
/// The UDP client
|
||||
///
|
||||
/// This includes a fair bit of synchronization due to the stateless nature of UDP.
|
||||
/// When many queries are sent in parallell, the response packets can come back
|
||||
/// in any order. For that reason, we fire off replies on the sending thread, but
|
||||
/// handle replies on a single thread. A channel is created for every response,
|
||||
/// and the caller will block on the channel until the a response is received.
|
||||
pub struct DnsNetworkClient {
|
||||
total_sent: AtomicUsize,
|
||||
total_failed: AtomicUsize,
|
||||
|
||||
/// Counter for assigning packet ids
|
||||
seq: AtomicUsize,
|
||||
|
||||
/// The listener socket
|
||||
socket: UdpSocket,
|
||||
|
||||
/// Queries in progress
|
||||
pending_queries: Arc<Mutex<Vec<PendingQuery>>>,
|
||||
}
|
||||
|
||||
/// A query in progress. This struct holds the `id` if the request, and a channel
|
||||
/// endpoint for returning a response back to the thread from which the query
|
||||
/// was posed.
|
||||
struct PendingQuery {
|
||||
seq: u16,
|
||||
timestamp: DateTime<Local>,
|
||||
tx: Sender<Option<DnsPacket>>,
|
||||
}
|
||||
|
||||
unsafe impl Send for DnsNetworkClient {}
|
||||
unsafe impl Sync for DnsNetworkClient {}
|
||||
|
||||
impl DnsNetworkClient {
|
||||
pub fn new(port: u16) -> DnsNetworkClient {
|
||||
DnsNetworkClient {
|
||||
total_sent: AtomicUsize::new(0),
|
||||
total_failed: AtomicUsize::new(0),
|
||||
seq: AtomicUsize::new(0),
|
||||
socket: UdpSocket::bind(("0.0.0.0", port)).unwrap(),
|
||||
pending_queries: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a DNS query using TCP transport
|
||||
///
|
||||
/// This is much simpler than using UDP, since the kernel will take care of
|
||||
/// packet ordering, connection state, timeouts etc.
|
||||
pub fn send_tcp_query(
|
||||
&self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: (&str, u16),
|
||||
recursive: bool,
|
||||
) -> Result<DnsPacket> {
|
||||
let _ = self.total_sent.fetch_add(1, Ordering::Release);
|
||||
|
||||
// Prepare request
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16;
|
||||
if packet.header.id + 1 == 0xFFFF {
|
||||
self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
packet.header.questions = 1;
|
||||
packet.header.recursion_desired = recursive;
|
||||
|
||||
packet.questions.push(DnsQuestion::new(qname.into(), qtype));
|
||||
|
||||
// Send query
|
||||
let mut req_buffer = BytePacketBuffer::new();
|
||||
packet.write(&mut req_buffer, 0xFFFF)?;
|
||||
|
||||
let mut socket = TcpStream::connect(server)?;
|
||||
|
||||
write_packet_length(&mut socket, req_buffer.pos())?;
|
||||
socket.write(&req_buffer.buf[0..req_buffer.pos])?;
|
||||
socket.flush()?;
|
||||
|
||||
let _ = read_packet_length(&mut socket)?;
|
||||
|
||||
let mut stream_buffer = StreamPacketBuffer::new(&mut socket);
|
||||
let packet = DnsPacket::from_buffer(&mut stream_buffer)?;
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
|
||||
/// Send a DNS query using UDP transport
|
||||
///
|
||||
/// This will construct a query packet, and fire it off to the specified server.
|
||||
/// The query is sent from the callee thread, but responses are read on a
|
||||
/// worker thread, and returned to this thread through a channel. Thus this
|
||||
/// method is thread safe, and can be used from any number of threads in
|
||||
/// parallell.
|
||||
pub fn send_udp_query(
|
||||
&self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: (&str, u16),
|
||||
recursive: bool,
|
||||
) -> Result<DnsPacket> {
|
||||
let _ = self.total_sent.fetch_add(1, Ordering::Release);
|
||||
|
||||
// Prepare request
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
packet.header.id = self.seq.fetch_add(1, Ordering::SeqCst) as u16;
|
||||
if packet.header.id + 1 == 0xFFFF {
|
||||
self.seq.compare_and_swap(0xFFFF, 0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
packet.header.questions = 1;
|
||||
packet.header.recursion_desired = recursive;
|
||||
|
||||
packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(qname.to_string(), qtype));
|
||||
|
||||
// Create a return channel, and add a `PendingQuery` to the list of lookups
|
||||
// in progress
|
||||
let (tx, rx) = channel();
|
||||
{
|
||||
let mut pending_queries = self
|
||||
.pending_queries
|
||||
.lock()
|
||||
.map_err(|_| ClientError::PoisonedLock)?;
|
||||
pending_queries.push(PendingQuery {
|
||||
seq: packet.header.id,
|
||||
timestamp: Local::now(),
|
||||
tx: tx,
|
||||
});
|
||||
}
|
||||
|
||||
// Send query
|
||||
let mut req_buffer = BytePacketBuffer::new();
|
||||
packet.write(&mut req_buffer, 512)?;
|
||||
self.socket
|
||||
.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
|
||||
|
||||
// Wait for response
|
||||
match rx.recv() {
|
||||
Ok(Some(qr)) => Ok(qr),
|
||||
Ok(None) => {
|
||||
let _ = self.total_failed.fetch_add(1, Ordering::Release);
|
||||
Err(ClientError::TimeOut)
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = self.total_failed.fetch_add(1, Ordering::Release);
|
||||
Err(ClientError::LookupFailed)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsClient for DnsNetworkClient {
|
||||
fn get_sent_count(&self) -> usize {
|
||||
self.total_sent.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
fn get_failed_count(&self) -> usize {
|
||||
self.total_failed.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// The run method launches a worker thread. Unless this thread is running, no
|
||||
/// responses will ever be generated, and clients will just block indefinitely.
|
||||
fn run(&self) -> Result<()> {
|
||||
// Start the thread for handling incoming responses
|
||||
{
|
||||
let socket_copy = self.socket.try_clone()?;
|
||||
let pending_queries_lock = self.pending_queries.clone();
|
||||
|
||||
Builder::new()
|
||||
.name("DnsNetworkClient-worker-thread".into())
|
||||
.spawn(move || {
|
||||
loop {
|
||||
// Read data into a buffer
|
||||
let mut res_buffer = BytePacketBuffer::new();
|
||||
match socket_copy.recv_from(&mut res_buffer.buf) {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Construct a DnsPacket from buffer, skipping the packet if parsing
|
||||
// failed
|
||||
let packet = match DnsPacket::from_buffer(&mut res_buffer) {
|
||||
Ok(packet) => packet,
|
||||
Err(err) => {
|
||||
println!(
|
||||
"DnsNetworkClient failed to parse packet with error: {}",
|
||||
err
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Acquire a lock on the pending_queries list, and search for a
|
||||
// matching PendingQuery to which to deliver the response.
|
||||
if let Ok(mut pending_queries) = pending_queries_lock.lock() {
|
||||
let mut matched_query = None;
|
||||
for (i, pending_query) in pending_queries.iter().enumerate() {
|
||||
if pending_query.seq == packet.header.id {
|
||||
// Matching query found, send the response
|
||||
let _ = pending_query.tx.send(Some(packet.clone()));
|
||||
|
||||
// Mark this index for removal from list
|
||||
matched_query = Some(i);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(idx) = matched_query {
|
||||
pending_queries.remove(idx);
|
||||
} else {
|
||||
println!("Discarding response for: {:?}", packet.questions[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
// Start the thread for timing out requests
|
||||
{
|
||||
let pending_queries_lock = self.pending_queries.clone();
|
||||
|
||||
Builder::new()
|
||||
.name("DnsNetworkClient-timeout-thread".into())
|
||||
.spawn(move || {
|
||||
let timeout = Duration::seconds(1);
|
||||
loop {
|
||||
if let Ok(mut pending_queries) = pending_queries_lock.lock() {
|
||||
let mut finished_queries = Vec::new();
|
||||
for (i, pending_query) in pending_queries.iter().enumerate() {
|
||||
let expires = pending_query.timestamp + timeout;
|
||||
if expires < Local::now() {
|
||||
let _ = pending_query.tx.send(None);
|
||||
finished_queries.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove `PendingQuery` objects from the list, in reverse order
|
||||
for idx in finished_queries.iter().rev() {
|
||||
pending_queries.remove(*idx);
|
||||
}
|
||||
}
|
||||
|
||||
sleep(SleepDuration::from_millis(100));
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_query(
|
||||
&self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: (&str, u16),
|
||||
recursive: bool,
|
||||
) -> Result<DnsPacket> {
|
||||
let packet = self.send_udp_query(qname, qtype, server, recursive)?;
|
||||
if !packet.header.truncated_message {
|
||||
return Ok(packet);
|
||||
}
|
||||
|
||||
println!("Truncated response - resending as TCP");
|
||||
self.send_tcp_query(qname, qtype, server, recursive)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType};
|
||||
|
||||
pub type StubCallback = dyn Fn(&str, QueryType, (&str, u16), bool) -> Result<DnsPacket>;
|
||||
|
||||
pub struct DnsStubClient {
|
||||
callback: Box<StubCallback>,
|
||||
}
|
||||
|
||||
impl<'a> DnsStubClient {
|
||||
pub fn new(callback: Box<StubCallback>) -> DnsStubClient {
|
||||
DnsStubClient { callback: callback }
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for DnsStubClient {}
|
||||
unsafe impl Sync for DnsStubClient {}
|
||||
|
||||
impl DnsClient for DnsStubClient {
|
||||
fn get_sent_count(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn get_failed_count(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn run(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_query(
|
||||
&self,
|
||||
qname: &str,
|
||||
qtype: QueryType,
|
||||
server: (&str, u16),
|
||||
recursive: bool,
|
||||
) -> Result<DnsPacket> {
|
||||
(self.callback)(qname, qtype, server, recursive)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_udp_client() {
|
||||
let client = DnsNetworkClient::new(31456);
|
||||
client.run().unwrap();
|
||||
|
||||
let res = client
|
||||
.send_udp_query("google.com", QueryType::A, ("8.8.8.8", 53), true)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.questions[0].name, "google.com");
|
||||
assert!(res.answers.len() > 0);
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("google.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_tcp_client() {
|
||||
let client = DnsNetworkClient::new(31457);
|
||||
let res = client
|
||||
.send_tcp_query("google.com", QueryType::A, ("8.8.8.8", 53), true)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.questions[0].name, "google.com");
|
||||
assert!(res.answers.len() > 0);
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("google.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,140 @@
|
||||
//! The `ServerContext in this thread holds the common state across the server
|
||||
|
||||
use std::fs;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use derive_more::{Display, Error, From};
|
||||
|
||||
use crate::dns::authority::Authority;
|
||||
use crate::dns::cache::SynchronizedCache;
|
||||
use crate::dns::client::{DnsClient, DnsNetworkClient};
|
||||
use crate::dns::resolve::{DnsResolver, ForwardingDnsResolver, RecursiveDnsResolver};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum ContextError {
|
||||
Authority(crate::dns::authority::AuthorityError),
|
||||
Client(crate::dns::client::ClientError),
|
||||
Io(std::io::Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, ContextError>;
|
||||
|
||||
pub struct ServerStatistics {
|
||||
pub tcp_query_count: AtomicUsize,
|
||||
pub udp_query_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl ServerStatistics {
|
||||
pub fn get_tcp_query_count(&self) -> usize {
|
||||
self.tcp_query_count.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
pub fn get_udp_query_count(&self) -> usize {
|
||||
self.udp_query_count.load(Ordering::Acquire)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ResolveStrategy {
|
||||
Recursive,
|
||||
Forward { host: String, port: u16 },
|
||||
}
|
||||
|
||||
pub struct ServerContext {
|
||||
pub authority: Authority,
|
||||
pub cache: SynchronizedCache,
|
||||
pub client: Box<dyn DnsClient + Sync + Send>,
|
||||
pub dns_port: u16,
|
||||
pub api_port: u16,
|
||||
pub resolve_strategy: ResolveStrategy,
|
||||
pub allow_recursive: bool,
|
||||
pub enable_udp: bool,
|
||||
pub enable_tcp: bool,
|
||||
pub enable_api: bool,
|
||||
pub statistics: ServerStatistics,
|
||||
pub zones_dir: &'static str
|
||||
}
|
||||
|
||||
impl Default for ServerContext {
|
||||
fn default() -> Self {
|
||||
ServerContext::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerContext {
|
||||
pub fn new() -> ServerContext {
|
||||
ServerContext {
|
||||
authority: Authority::new(),
|
||||
cache: SynchronizedCache::new(),
|
||||
client: Box::new(DnsNetworkClient::new(34255)),
|
||||
dns_port: 53,
|
||||
api_port: 5380,
|
||||
resolve_strategy: ResolveStrategy::Recursive,
|
||||
allow_recursive: true,
|
||||
enable_udp: true,
|
||||
enable_tcp: true,
|
||||
enable_api: true,
|
||||
statistics: ServerStatistics {
|
||||
tcp_query_count: AtomicUsize::new(0),
|
||||
udp_query_count: AtomicUsize::new(0),
|
||||
},
|
||||
zones_dir: "zones",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn initialize(&mut self) -> Result<()> {
|
||||
// Create zones directory if it doesn't exist
|
||||
fs::create_dir_all(self.zones_dir)?;
|
||||
|
||||
// Start UDP client thread
|
||||
self.client.run()?;
|
||||
|
||||
// Load authority data
|
||||
self.authority.load()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn create_resolver(&self, ptr: Arc<ServerContext>) -> Box<dyn DnsResolver> {
|
||||
match self.resolve_strategy {
|
||||
ResolveStrategy::Recursive => Box::new(RecursiveDnsResolver::new(ptr)),
|
||||
ResolveStrategy::Forward { ref host, port } => {
|
||||
Box::new(ForwardingDnsResolver::new(ptr, (host.clone(), port)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::dns::authority::Authority;
|
||||
use crate::dns::cache::SynchronizedCache;
|
||||
|
||||
use crate::dns::client::tests::{DnsStubClient, StubCallback};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn create_test_context(callback: Box<StubCallback>) -> Arc<ServerContext> {
|
||||
Arc::new(ServerContext {
|
||||
authority: Authority::new(),
|
||||
cache: SynchronizedCache::new(),
|
||||
client: Box::new(DnsStubClient::new(callback)),
|
||||
dns_port: 53,
|
||||
api_port: 5380,
|
||||
resolve_strategy: ResolveStrategy::Recursive,
|
||||
allow_recursive: true,
|
||||
enable_udp: true,
|
||||
enable_tcp: true,
|
||||
enable_api: true,
|
||||
statistics: ServerStatistics {
|
||||
tcp_query_count: AtomicUsize::new(0),
|
||||
udp_query_count: AtomicUsize::new(0),
|
||||
},
|
||||
zones_dir: "zones",
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
/*
|
||||
Copyright 2018 Emil Hernvall
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||
of the Software, and to permit persons to whom the Software is furnished to do
|
||||
so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
*/
|
||||
|
||||
//! The dns module implements the DNS protocol and the related functions
|
||||
|
||||
pub mod authority;
|
||||
pub mod buffer;
|
||||
pub mod cache;
|
||||
pub mod client;
|
||||
pub mod context;
|
||||
pub mod protocol;
|
||||
pub mod resolve;
|
||||
pub mod server;
|
||||
|
||||
mod netutil;
|
@ -0,0 +1,19 @@
|
||||
use std::io::{Read, Result, Write};
|
||||
use std::net::TcpStream;
|
||||
|
||||
pub fn read_packet_length(stream: &mut TcpStream) -> Result<u16> {
|
||||
let mut len_buffer = [0; 2];
|
||||
stream.read(&mut len_buffer)?;
|
||||
|
||||
Ok(((len_buffer[0] as u16) << 8) | (len_buffer[1] as u16))
|
||||
}
|
||||
|
||||
pub fn write_packet_length(stream: &mut TcpStream, len: usize) -> Result<()> {
|
||||
let mut len_buffer = [0; 2];
|
||||
len_buffer[0] = (len >> 8) as u8;
|
||||
len_buffer[1] = (len & 0xFF) as u8;
|
||||
|
||||
stream.write(&len_buffer)?;
|
||||
|
||||
Ok(())
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,569 @@
|
||||
//! resolver implementations implementing different strategies for answering
|
||||
//! incoming queries
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::vec::Vec;
|
||||
|
||||
use derive_more::{Display, Error, From};
|
||||
|
||||
use crate::dns::context::ServerContext;
|
||||
use crate::dns::protocol::{DnsPacket, QueryType, ResultCode};
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum ResolveError {
|
||||
Client(crate::dns::client::ClientError),
|
||||
Cache(crate::dns::cache::CacheError),
|
||||
Io(std::io::Error),
|
||||
NoServerFound,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, ResolveError>;
|
||||
|
||||
pub trait DnsResolver {
|
||||
fn get_context(&self) -> Arc<ServerContext>;
|
||||
|
||||
fn resolve(&mut self, qname: &str, qtype: QueryType, recursive: bool) -> Result<DnsPacket> {
|
||||
if let QueryType::UNKNOWN(_) = qtype {
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.rescode = ResultCode::NOTIMP;
|
||||
return Ok(packet);
|
||||
}
|
||||
|
||||
let context = self.get_context();
|
||||
|
||||
if let Some(qr) = context.authority.query(qname, qtype) {
|
||||
return Ok(qr);
|
||||
}
|
||||
|
||||
if !recursive || !context.allow_recursive {
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.rescode = ResultCode::REFUSED;
|
||||
return Ok(packet);
|
||||
}
|
||||
|
||||
if let Some(qr) = context.cache.lookup(qname, qtype) {
|
||||
return Ok(qr);
|
||||
}
|
||||
|
||||
if qtype == QueryType::A || qtype == QueryType::AAAA {
|
||||
if let Some(qr) = context.cache.lookup(qname, QueryType::CNAME) {
|
||||
return Ok(qr);
|
||||
}
|
||||
}
|
||||
|
||||
self.perform(qname, qtype)
|
||||
}
|
||||
|
||||
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket>;
|
||||
}
|
||||
|
||||
/// A Forwarding DNS Resolver
|
||||
///
|
||||
/// This resolver uses an external DNS server to service a query
|
||||
pub struct ForwardingDnsResolver {
|
||||
context: Arc<ServerContext>,
|
||||
server: (String, u16),
|
||||
}
|
||||
|
||||
impl ForwardingDnsResolver {
|
||||
pub fn new(context: Arc<ServerContext>, server: (String, u16)) -> ForwardingDnsResolver {
|
||||
ForwardingDnsResolver {
|
||||
context: context,
|
||||
server: server,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsResolver for ForwardingDnsResolver {
|
||||
fn get_context(&self) -> Arc<ServerContext> {
|
||||
self.context.clone()
|
||||
}
|
||||
|
||||
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket> {
|
||||
let &(ref host, port) = &self.server;
|
||||
let result = self
|
||||
.context
|
||||
.client
|
||||
.send_query(qname, qtype, (host.as_str(), port), true)?;
|
||||
|
||||
self.context.cache.store(&result.answers)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Recursive DNS resolver
|
||||
///
|
||||
/// This resolver can answer any request using the root servers of the internet
|
||||
pub struct RecursiveDnsResolver {
|
||||
context: Arc<ServerContext>,
|
||||
}
|
||||
|
||||
impl RecursiveDnsResolver {
|
||||
pub fn new(context: Arc<ServerContext>) -> RecursiveDnsResolver {
|
||||
RecursiveDnsResolver { context: context }
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsResolver for RecursiveDnsResolver {
|
||||
fn get_context(&self) -> Arc<ServerContext> {
|
||||
self.context.clone()
|
||||
}
|
||||
|
||||
fn perform(&mut self, qname: &str, qtype: QueryType) -> Result<DnsPacket> {
|
||||
// Find the closest name server by splitting the label and progessively
|
||||
// moving towards the root servers. I.e. check "google.com", then "com",
|
||||
// and finally "".
|
||||
let mut tentative_ns = None;
|
||||
|
||||
let labels = qname.split('.').collect::<Vec<&str>>();
|
||||
for lbl_idx in 0..labels.len() + 1 {
|
||||
let domain = labels[lbl_idx..].join(".");
|
||||
|
||||
match self
|
||||
.context
|
||||
.cache
|
||||
.lookup(&domain, QueryType::NS)
|
||||
.and_then(|qr| qr.get_unresolved_ns(&domain))
|
||||
.and_then(|ns| self.context.cache.lookup(&ns, QueryType::A))
|
||||
.and_then(|qr| qr.get_random_a())
|
||||
{
|
||||
Some(addr) => {
|
||||
tentative_ns = Some(addr);
|
||||
break;
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
let mut ns = tentative_ns.ok_or_else(|| ResolveError::NoServerFound)?;
|
||||
|
||||
// Start querying name servers
|
||||
loop {
|
||||
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
|
||||
|
||||
let ns_copy = ns.clone();
|
||||
|
||||
let server = (ns_copy.as_str(), 53);
|
||||
let response = self
|
||||
.context
|
||||
.client
|
||||
.send_query(qname, qtype.clone(), server, false)?;
|
||||
|
||||
// If we've got an actual answer, we're done!
|
||||
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
|
||||
let _ = self.context.cache.store(&response.answers);
|
||||
let _ = self.context.cache.store(&response.authorities);
|
||||
let _ = self.context.cache.store(&response.resources);
|
||||
return Ok(response.clone());
|
||||
}
|
||||
|
||||
if response.header.rescode == ResultCode::NXDOMAIN {
|
||||
if let Some(ttl) = response.get_ttl_from_soa() {
|
||||
let _ = self.context.cache.store_nxdomain(qname, qtype, ttl);
|
||||
}
|
||||
return Ok(response.clone());
|
||||
}
|
||||
|
||||
// Otherwise, try to find a new nameserver based on NS and a
|
||||
// corresponding A record in the additional section
|
||||
if let Some(new_ns) = response.get_resolved_ns(qname) {
|
||||
// If there is such a record, we can retry the loop with that NS
|
||||
ns = new_ns.clone();
|
||||
let _ = self.context.cache.store(&response.answers);
|
||||
let _ = self.context.cache.store(&response.authorities);
|
||||
let _ = self.context.cache.store(&response.resources);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// If not, we'll have to resolve the ip of a NS record
|
||||
let new_ns_name = match response.get_unresolved_ns(qname) {
|
||||
Some(x) => x,
|
||||
None => return Ok(response.clone()),
|
||||
};
|
||||
|
||||
// Recursively resolve the NS
|
||||
let recursive_response = self.resolve(&new_ns_name, QueryType::A, true)?;
|
||||
|
||||
// Pick a random IP and restart
|
||||
if let Some(new_ns) = recursive_response.get_random_a() {
|
||||
ns = new_ns.clone();
|
||||
} else {
|
||||
return Ok(response.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode, TransientTtl};
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::dns::context::tests::create_test_context;
|
||||
use crate::dns::context::ResolveStrategy;
|
||||
|
||||
#[test]
|
||||
fn test_forwarding_resolver() {
|
||||
let mut context = create_test_context(Box::new(|qname, _, _, _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
if qname == "google.com" {
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else {
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
match Arc::get_mut(&mut context) {
|
||||
Some(mut ctx) => {
|
||||
ctx.resolve_strategy = ResolveStrategy::Forward {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 53,
|
||||
};
|
||||
}
|
||||
None => panic!(),
|
||||
}
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
|
||||
// First verify that we get a match back
|
||||
{
|
||||
let res = match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(1, res.answers.len());
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("google.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
};
|
||||
|
||||
// Do the same lookup again, and verify that it's present in the cache
|
||||
// and that the counter has been updated
|
||||
{
|
||||
let res = match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(1, res.answers.len());
|
||||
|
||||
let list = match context.cache.list() {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(1, list.len());
|
||||
|
||||
assert_eq!("google.com", list[0].domain);
|
||||
assert_eq!(1, list[0].record_types.len());
|
||||
assert_eq!(1, list[0].hits);
|
||||
};
|
||||
|
||||
// Do a failed lookup
|
||||
{
|
||||
let res = match resolver.resolve("yahoo.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(0, res.answers.len());
|
||||
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_resolver_with_no_nameserver() {
|
||||
let context = create_test_context(Box::new(|_, _, _, _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
|
||||
// Expect failure when no name servers are available
|
||||
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_resolver_with_missing_a_record() {
|
||||
let context = create_test_context(Box::new(|_, _, _, _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
|
||||
// Expect failure when no name servers are available
|
||||
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Insert name server, but no corresponding A record
|
||||
let mut nameservers = Vec::new();
|
||||
nameservers.push(DnsRecord::NS {
|
||||
domain: "".to_string(),
|
||||
host: "a.myroot.net".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
let _ = context.cache.store(&nameservers);
|
||||
|
||||
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_resolver_match_order() {
|
||||
let context = create_test_context(Box::new(|_, _, (server, _), _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
if server == "127.0.0.1" {
|
||||
packet.header.id = 1;
|
||||
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "a.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
return Ok(packet);
|
||||
} else if server == "127.0.0.2" {
|
||||
packet.header.id = 2;
|
||||
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "b.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
return Ok(packet);
|
||||
} else if server == "127.0.0.3" {
|
||||
packet.header.id = 3;
|
||||
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "c.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
return Ok(packet);
|
||||
}
|
||||
|
||||
packet.header.id = 999;
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
|
||||
// Expect failure when no name servers are available
|
||||
if let Ok(_) = resolver.resolve("google.com", QueryType::A, true) {
|
||||
panic!();
|
||||
}
|
||||
|
||||
// Insert root servers
|
||||
{
|
||||
let mut nameservers = Vec::new();
|
||||
nameservers.push(DnsRecord::NS {
|
||||
domain: "".to_string(),
|
||||
host: "a.myroot.net".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
nameservers.push(DnsRecord::A {
|
||||
domain: "a.myroot.net".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
let _ = context.cache.store(&nameservers);
|
||||
}
|
||||
|
||||
match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(packet) => {
|
||||
assert_eq!(1, packet.header.id);
|
||||
}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
// Insert TLD servers
|
||||
{
|
||||
let mut nameservers = Vec::new();
|
||||
nameservers.push(DnsRecord::NS {
|
||||
domain: "com".to_string(),
|
||||
host: "a.mytld.net".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
nameservers.push(DnsRecord::A {
|
||||
domain: "a.mytld.net".to_string(),
|
||||
addr: "127.0.0.2".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
let _ = context.cache.store(&nameservers);
|
||||
}
|
||||
|
||||
match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(packet) => {
|
||||
assert_eq!(2, packet.header.id);
|
||||
}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
|
||||
// Insert authoritative servers
|
||||
{
|
||||
let mut nameservers = Vec::new();
|
||||
nameservers.push(DnsRecord::NS {
|
||||
domain: "google.com".to_string(),
|
||||
host: "ns1.google.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
nameservers.push(DnsRecord::A {
|
||||
domain: "ns1.google.com".to_string(),
|
||||
addr: "127.0.0.3".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
let _ = context.cache.store(&nameservers);
|
||||
}
|
||||
|
||||
match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(packet) => {
|
||||
assert_eq!(3, packet.header.id);
|
||||
}
|
||||
Err(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_resolver_successfully() {
|
||||
let context = create_test_context(Box::new(|qname, _, _, _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
if qname == "google.com" {
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else {
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
|
||||
packet.authorities.push(DnsRecord::SOA {
|
||||
domain: "google.com".to_string(),
|
||||
r_name: "google.com".to_string(),
|
||||
m_name: "google.com".to_string(),
|
||||
serial: 0,
|
||||
refresh: 3600,
|
||||
retry: 3600,
|
||||
expire: 3600,
|
||||
minimum: 3600,
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
|
||||
// Insert name servers
|
||||
let mut nameservers = Vec::new();
|
||||
nameservers.push(DnsRecord::NS {
|
||||
domain: "google.com".to_string(),
|
||||
host: "ns1.google.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
nameservers.push(DnsRecord::A {
|
||||
domain: "ns1.google.com".to_string(),
|
||||
addr: "127.0.0.1".parse().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
|
||||
let _ = context.cache.store(&nameservers);
|
||||
|
||||
// Check that we can successfully resolve
|
||||
{
|
||||
let res = match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(1, res.answers.len());
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("google.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
};
|
||||
|
||||
// And that we won't find anything for a domain that isn't present
|
||||
{
|
||||
let res = match resolver.resolve("foobar.google.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
|
||||
assert_eq!(0, res.answers.len());
|
||||
};
|
||||
|
||||
// Perform another successful query, that should hit the cache
|
||||
{
|
||||
let res = match resolver.resolve("google.com", QueryType::A, true) {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(1, res.answers.len());
|
||||
};
|
||||
|
||||
// Now check that the cache is used, and that the statistics is correct
|
||||
{
|
||||
let list = match context.cache.list() {
|
||||
Ok(x) => x,
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
|
||||
assert_eq!(3, list.len());
|
||||
|
||||
// Check statistics for google entry
|
||||
assert_eq!("google.com", list[1].domain);
|
||||
|
||||
// Should have a NS record and an A record for a total of 2 record types
|
||||
assert_eq!(2, list[1].record_types.len());
|
||||
|
||||
// Should have been hit two times for NS google.com and once for
|
||||
// A google.com
|
||||
assert_eq!(3, list[1].hits);
|
||||
|
||||
assert_eq!("ns1.google.com", list[2].domain);
|
||||
assert_eq!(1, list[2].record_types.len());
|
||||
assert_eq!(2, list[2].hits);
|
||||
};
|
||||
}
|
||||
}
|
@ -0,0 +1,608 @@
|
||||
//! UDP and TCP server implementations for DNS
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{Shutdown, TcpListener, TcpStream, UdpSocket};
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::mpsc::{channel, Sender};
|
||||
use std::sync::{Arc, Condvar, Mutex};
|
||||
use std::thread::Builder;
|
||||
|
||||
use derive_more::{Display, Error, From};
|
||||
use rand::random;
|
||||
|
||||
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
|
||||
use crate::dns::context::ServerContext;
|
||||
use crate::dns::netutil::{read_packet_length, write_packet_length};
|
||||
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
|
||||
use crate::dns::resolve::DnsResolver;
|
||||
|
||||
#[derive(Debug, Display, From, Error)]
|
||||
pub enum ServerError {
|
||||
Io(std::io::Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, ServerError>;
|
||||
|
||||
macro_rules! return_or_report {
|
||||
( $x:expr, $message:expr ) => {
|
||||
match $x {
|
||||
Ok(res) => res,
|
||||
Err(_) => {
|
||||
println!($message);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! ignore_or_report {
|
||||
( $x:expr, $message:expr ) => {
|
||||
match $x {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
println!($message);
|
||||
return;
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/// Common trait for DNS servers
|
||||
pub trait DnsServer {
|
||||
/// Initialize the server and start listenening
|
||||
///
|
||||
/// This method should _NOT_ block. Rather, servers are expected to spawn a new
|
||||
/// thread to handle requests and return immediately.
|
||||
fn run_server(self) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Utility function for resolving domains referenced in for example CNAME or SRV
|
||||
/// records. This usually spares the client from having to perform additional
|
||||
/// lookups.
|
||||
fn resolve_cnames(
|
||||
lookup_list: &[DnsRecord],
|
||||
results: &mut Vec<DnsPacket>,
|
||||
resolver: &mut Box<dyn DnsResolver>,
|
||||
depth: u16,
|
||||
) {
|
||||
if depth > 10 {
|
||||
return;
|
||||
}
|
||||
|
||||
for ref rec in lookup_list {
|
||||
match **rec {
|
||||
DnsRecord::CNAME { ref host, .. } | DnsRecord::SRV { ref host, .. } => {
|
||||
if let Ok(result2) = resolver.resolve(host, QueryType::A, true) {
|
||||
let new_unmatched = result2.get_unresolved_cnames();
|
||||
results.push(result2);
|
||||
|
||||
resolve_cnames(&new_unmatched, results, resolver, depth + 1);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the actual work for a query
|
||||
///
|
||||
/// Incoming requests are validated to make sure they are well formed and adhere
|
||||
/// to the server configuration. If so, the request will be passed on to the
|
||||
/// active resolver and a query will be performed. It will also resolve some
|
||||
/// possible references within the query, such as CNAME hosts.
|
||||
///
|
||||
/// This function will always return a valid packet, even if the request could not
|
||||
/// be performed, since we still want to send something back to the client.
|
||||
pub fn execute_query(context: Arc<ServerContext>, request: &DnsPacket) -> DnsPacket {
|
||||
let mut packet = DnsPacket::new();
|
||||
packet.header.id = request.header.id;
|
||||
packet.header.recursion_available = context.allow_recursive;
|
||||
packet.header.response = true;
|
||||
|
||||
if request.header.recursion_desired && !context.allow_recursive {
|
||||
packet.header.rescode = ResultCode::REFUSED;
|
||||
} else if request.questions.is_empty() {
|
||||
packet.header.rescode = ResultCode::FORMERR;
|
||||
} else {
|
||||
let mut results = Vec::new();
|
||||
|
||||
let question = &request.questions[0];
|
||||
packet.questions.push(question.clone());
|
||||
|
||||
let mut resolver = context.create_resolver(context.clone());
|
||||
let rescode = match resolver.resolve(
|
||||
&question.name,
|
||||
question.qtype,
|
||||
request.header.recursion_desired,
|
||||
) {
|
||||
Ok(result) => {
|
||||
let rescode = result.header.rescode;
|
||||
|
||||
let unmatched = result.get_unresolved_cnames();
|
||||
results.push(result);
|
||||
|
||||
resolve_cnames(&unmatched, &mut results, &mut resolver, 0);
|
||||
|
||||
rescode
|
||||
}
|
||||
Err(err) => {
|
||||
println!(
|
||||
"Failed to resolve {:?} {}: {:?}",
|
||||
question.qtype, question.name, err
|
||||
);
|
||||
ResultCode::SERVFAIL
|
||||
}
|
||||
};
|
||||
|
||||
packet.header.rescode = rescode;
|
||||
|
||||
for result in results {
|
||||
for rec in result.answers {
|
||||
packet.answers.push(rec);
|
||||
}
|
||||
for rec in result.authorities {
|
||||
packet.authorities.push(rec);
|
||||
}
|
||||
for rec in result.resources {
|
||||
packet.resources.push(rec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
packet
|
||||
}
|
||||
|
||||
/// The UDP server
|
||||
///
|
||||
/// Accepts DNS queries through UDP, and uses the `ServerContext` to determine
|
||||
/// how to service the request. Packets are read on a single thread, after which
|
||||
/// a new thread is spawned to service the request asynchronously.
|
||||
pub struct DnsUdpServer {
|
||||
context: Arc<ServerContext>,
|
||||
request_queue: Arc<Mutex<VecDeque<(SocketAddr, DnsPacket)>>>,
|
||||
request_cond: Arc<Condvar>,
|
||||
thread_count: usize,
|
||||
}
|
||||
|
||||
impl DnsUdpServer {
|
||||
pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsUdpServer {
|
||||
DnsUdpServer {
|
||||
context: context,
|
||||
request_queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
request_cond: Arc::new(Condvar::new()),
|
||||
thread_count: thread_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsServer for DnsUdpServer {
|
||||
/// Launch the server
|
||||
///
|
||||
/// This method takes ownership of the server, preventing the method from
|
||||
/// being called multiple times.
|
||||
fn run_server(self) -> Result<()> {
|
||||
// Bind the socket
|
||||
let socket = UdpSocket::bind(("0.0.0.0", self.context.dns_port))?;
|
||||
|
||||
// Spawn threads for handling requests
|
||||
for thread_id in 0..self.thread_count {
|
||||
let socket_clone = match socket.try_clone() {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
println!("Failed to clone socket when starting UDP server: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let context = self.context.clone();
|
||||
let request_cond = self.request_cond.clone();
|
||||
let request_queue = self.request_queue.clone();
|
||||
|
||||
let name = "DnsUdpServer-request-".to_string() + &thread_id.to_string();
|
||||
let _ = Builder::new().name(name).spawn(move || {
|
||||
loop {
|
||||
// Acquire lock, and wait on the condition until data is
|
||||
// available. Then proceed with popping an entry of the queue.
|
||||
let (src, request) = match request_queue
|
||||
.lock()
|
||||
.ok()
|
||||
.and_then(|x| request_cond.wait(x).ok())
|
||||
.and_then(|mut x| x.pop_front())
|
||||
{
|
||||
Some(x) => x,
|
||||
None => {
|
||||
println!("Not expected to happen!");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut size_limit = 512;
|
||||
|
||||
// Check for EDNS
|
||||
if request.resources.len() == 1 {
|
||||
if let DnsRecord::OPT { packet_len, .. } = request.resources[0] {
|
||||
size_limit = packet_len as usize;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a response buffer, and ask the context for an appropriate
|
||||
// resolver
|
||||
let mut res_buffer = VectorPacketBuffer::new();
|
||||
|
||||
let mut packet = execute_query(context.clone(), &request);
|
||||
let _ = packet.write(&mut res_buffer, size_limit);
|
||||
|
||||
// Fire off the response
|
||||
let len = res_buffer.pos();
|
||||
let data = return_or_report!(
|
||||
res_buffer.get_range(0, len),
|
||||
"Failed to get buffer data"
|
||||
);
|
||||
ignore_or_report!(
|
||||
socket_clone.send_to(data, src),
|
||||
"Failed to send response packet"
|
||||
);
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
// Start servicing requests
|
||||
let _ = Builder::new()
|
||||
.name("DnsUdpServer-incoming".into())
|
||||
.spawn(move || {
|
||||
loop {
|
||||
let _ = self
|
||||
.context
|
||||
.statistics
|
||||
.udp_query_count
|
||||
.fetch_add(1, Ordering::Release);
|
||||
|
||||
// Read a query packet
|
||||
let mut req_buffer = BytePacketBuffer::new();
|
||||
let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
println!("Failed to read from UDP socket: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse it
|
||||
let request = match DnsPacket::from_buffer(&mut req_buffer) {
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
println!("Failed to parse UDP query packet: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Acquire lock, add request to queue, and notify waiting threads
|
||||
// using the condition.
|
||||
match self.request_queue.lock() {
|
||||
Ok(mut queue) => {
|
||||
queue.push_back((src, request));
|
||||
self.request_cond.notify_one();
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Failed to send UDP request for processing: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// TCP DNS server
|
||||
pub struct DnsTcpServer {
|
||||
context: Arc<ServerContext>,
|
||||
senders: Vec<Sender<TcpStream>>,
|
||||
thread_count: usize,
|
||||
}
|
||||
|
||||
impl DnsTcpServer {
|
||||
pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsTcpServer {
|
||||
DnsTcpServer {
|
||||
context: context,
|
||||
senders: Vec::new(),
|
||||
thread_count: thread_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsServer for DnsTcpServer {
|
||||
fn run_server(mut self) -> Result<()> {
|
||||
let socket = TcpListener::bind(("0.0.0.0", self.context.dns_port))?;
|
||||
|
||||
// Spawn threads for handling requests, and create the channels
|
||||
for thread_id in 0..self.thread_count {
|
||||
let (tx, rx) = channel();
|
||||
self.senders.push(tx);
|
||||
|
||||
let context = self.context.clone();
|
||||
|
||||
let name = "DnsTcpServer-request-".to_string() + &thread_id.to_string();
|
||||
let _ = Builder::new().name(name).spawn(move || {
|
||||
loop {
|
||||
let mut stream = match rx.recv() {
|
||||
Ok(x) => x,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let _ = context
|
||||
.statistics
|
||||
.tcp_query_count
|
||||
.fetch_add(1, Ordering::Release);
|
||||
|
||||
// When DNS packets are sent over TCP, they're prefixed with a two byte
|
||||
// length. We don't really need to know the length in advance, so we
|
||||
// just move past it and continue reading as usual
|
||||
ignore_or_report!(
|
||||
read_packet_length(&mut stream),
|
||||
"Failed to read query packet length"
|
||||
);
|
||||
|
||||
let request = {
|
||||
let mut stream_buffer = StreamPacketBuffer::new(&mut stream);
|
||||
return_or_report!(
|
||||
DnsPacket::from_buffer(&mut stream_buffer),
|
||||
"Failed to read query packet"
|
||||
)
|
||||
};
|
||||
|
||||
let mut res_buffer = VectorPacketBuffer::new();
|
||||
|
||||
let mut packet = execute_query(context.clone(), &request);
|
||||
ignore_or_report!(
|
||||
packet.write(&mut res_buffer, 0xFFFF),
|
||||
"Failed to write packet to buffer"
|
||||
);
|
||||
|
||||
// As is the case for incoming queries, we need to send a 2 byte length
|
||||
// value before handing of the actual packet.
|
||||
let len = res_buffer.pos();
|
||||
ignore_or_report!(
|
||||
write_packet_length(&mut stream, len),
|
||||
"Failed to write packet size"
|
||||
);
|
||||
|
||||
// Now we can go ahead and write the actual packet
|
||||
let data = return_or_report!(
|
||||
res_buffer.get_range(0, len),
|
||||
"Failed to get packet data"
|
||||
);
|
||||
|
||||
ignore_or_report!(stream.write(data), "Failed to write response packet");
|
||||
|
||||
ignore_or_report!(stream.shutdown(Shutdown::Both), "Failed to shutdown socket");
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
let _ = Builder::new()
|
||||
.name("DnsTcpServer-incoming".into())
|
||||
.spawn(move || {
|
||||
for wrap_stream in socket.incoming() {
|
||||
let stream = match wrap_stream {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
println!("Failed to accept TCP connection: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Hand it off to a worker thread
|
||||
let thread_no = random::<usize>() % self.thread_count;
|
||||
match self.senders[thread_no].send(stream) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
println!(
|
||||
"Failed to send TCP request for processing on thread {}: {}",
|
||||
thread_no, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::dns::protocol::{
|
||||
DnsPacket, DnsQuestion, DnsRecord, QueryType, ResultCode, TransientTtl,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
use crate::dns::context::tests::create_test_context;
|
||||
use crate::dns::context::ResolveStrategy;
|
||||
|
||||
fn build_query(qname: &str, qtype: QueryType) -> DnsPacket {
|
||||
let mut query_packet = DnsPacket::new();
|
||||
query_packet.header.recursion_desired = true;
|
||||
|
||||
query_packet
|
||||
.questions
|
||||
.push(DnsQuestion::new(qname.into(), qtype));
|
||||
|
||||
query_packet
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_execute_query() {
|
||||
// Construct a context to execute some queries successfully
|
||||
let mut context = create_test_context(Box::new(|qname, qtype, _, _| {
|
||||
let mut packet = DnsPacket::new();
|
||||
|
||||
if qname == "google.com" {
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "google.com".to_string(),
|
||||
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else if qname == "www.facebook.com" && qtype == QueryType::CNAME {
|
||||
packet.answers.push(DnsRecord::CNAME {
|
||||
domain: "www.facebook.com".to_string(),
|
||||
host: "cdn.facebook.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "cdn.facebook.com".to_string(),
|
||||
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else if qname == "www.microsoft.com" && qtype == QueryType::CNAME {
|
||||
packet.answers.push(DnsRecord::CNAME {
|
||||
domain: "www.microsoft.com".to_string(),
|
||||
host: "cdn.microsoft.com".to_string(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else if qname == "cdn.microsoft.com" && qtype == QueryType::A {
|
||||
packet.answers.push(DnsRecord::A {
|
||||
domain: "cdn.microsoft.com".to_string(),
|
||||
addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
|
||||
ttl: TransientTtl(3600),
|
||||
});
|
||||
} else {
|
||||
packet.header.rescode = ResultCode::NXDOMAIN;
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}));
|
||||
|
||||
match Arc::get_mut(&mut context) {
|
||||
Some(mut ctx) => {
|
||||
ctx.resolve_strategy = ResolveStrategy::Forward {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 53,
|
||||
};
|
||||
}
|
||||
None => panic!(),
|
||||
}
|
||||
|
||||
// A successful resolve
|
||||
{
|
||||
let res = execute_query(context.clone(), &build_query("google.com", QueryType::A));
|
||||
assert_eq!(1, res.answers.len());
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("google.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
};
|
||||
|
||||
// A successful resolve, that also resolves a CNAME without recursive lookup
|
||||
{
|
||||
let res = execute_query(
|
||||
context.clone(),
|
||||
&build_query("www.facebook.com", QueryType::CNAME),
|
||||
);
|
||||
assert_eq!(2, res.answers.len());
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::CNAME { ref domain, .. } => {
|
||||
assert_eq!("www.facebook.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
|
||||
match res.answers[1] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("cdn.facebook.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
};
|
||||
|
||||
// A successful resolve, that also resolves a CNAME through recursive lookup
|
||||
{
|
||||
let res = execute_query(
|
||||
context.clone(),
|
||||
&build_query("www.microsoft.com", QueryType::CNAME),
|
||||
);
|
||||
assert_eq!(2, res.answers.len());
|
||||
|
||||
match res.answers[0] {
|
||||
DnsRecord::CNAME { ref domain, .. } => {
|
||||
assert_eq!("www.microsoft.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
|
||||
match res.answers[1] {
|
||||
DnsRecord::A { ref domain, .. } => {
|
||||
assert_eq!("cdn.microsoft.com", domain);
|
||||
}
|
||||
_ => panic!(),
|
||||
}
|
||||
};
|
||||
|
||||
// An unsuccessful resolve, but without any error
|
||||
{
|
||||
let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A));
|
||||
assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
|
||||
assert_eq!(0, res.answers.len());
|
||||
};
|
||||
|
||||
// Disable recursive resolves to generate a failure
|
||||
match Arc::get_mut(&mut context) {
|
||||
Some(mut ctx) => {
|
||||
ctx.allow_recursive = false;
|
||||
}
|
||||
None => panic!(),
|
||||
}
|
||||
|
||||
// This should generate an error code, since recursive resolves are
|
||||
// no longer allowed
|
||||
{
|
||||
let res = execute_query(context.clone(), &build_query("yahoo.com", QueryType::A));
|
||||
assert_eq!(ResultCode::REFUSED, res.header.rescode);
|
||||
assert_eq!(0, res.answers.len());
|
||||
};
|
||||
|
||||
// Send a query without a question, which should fail with an error code
|
||||
{
|
||||
let query_packet = DnsPacket::new();
|
||||
let res = execute_query(context.clone(), &query_packet);
|
||||
assert_eq!(ResultCode::FORMERR, res.header.rescode);
|
||||
assert_eq!(0, res.answers.len());
|
||||
};
|
||||
|
||||
// Now construct a context where the dns client will return a failure
|
||||
let mut context2 = create_test_context(Box::new(|_, _, _, _| {
|
||||
Err(crate::dns::client::ClientError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"Fail",
|
||||
)))
|
||||
}));
|
||||
|
||||
match Arc::get_mut(&mut context2) {
|
||||
Some(mut ctx) => {
|
||||
ctx.resolve_strategy = ResolveStrategy::Forward {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 53,
|
||||
};
|
||||
}
|
||||
None => panic!(),
|
||||
}
|
||||
|
||||
// We expect this to set the server failure rescode
|
||||
{
|
||||
let res = execute_query(context2.clone(), &build_query("yahoo.com", QueryType::A));
|
||||
assert_eq!(ResultCode::SERVFAIL, res.header.rescode);
|
||||
assert_eq!(0, res.answers.len());
|
||||
};
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue