You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Alfis/src/dns/server.rs

556 lines
21 KiB
Rust

//! UDP and TCP server implementations for DNS
use std::collections::VecDeque;
use std::io::Write;
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
use std::sync::atomic::Ordering;
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::Builder;
use derive_more::{Display, Error, From};
use log::{debug, error, warn};
use rand::random;
use crate::dns::buffer::{BytePacketBuffer, PacketBuffer, StreamPacketBuffer, VectorPacketBuffer};
use crate::dns::context::ServerContext;
use crate::dns::netutil::{read_packet_length, write_packet_length};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
use crate::dns::resolve::DnsResolver;
/// Errors reported by the DNS servers in this module.
///
/// `From` is derived, so a `std::io::Error` is converted automatically when it is
/// propagated with `?` (as done for the thread spawns in `run_server`).
#[derive(Debug, Display, From, Error)]
pub enum ServerError {
/// Wraps a `std::io::Error`
Io(std::io::Error)
}
/// Result alias used by the server functions in this module
type Result<T> = std::result::Result<T, ServerError>;
/// Evaluates to the `Ok` value of `$x`. On `Err` the error itself is discarded,
/// `$message` is logged at debug level and a bare `return` is executed.
///
/// NOTE: the `return` leaves the enclosing function or closure, not just the current
/// loop iteration, so it must only be used where the enclosing item returns `()`
/// and abandoning it entirely is the intended reaction.
macro_rules! return_or_report {
( $x:expr, $message:expr ) => {
match $x {
Ok(res) => res,
Err(_) => {
debug!($message);
return;
}
}
};
}
/// Statement form of `return_or_report!`: the `Ok` value of `$x` is dropped. On `Err`
/// the error is discarded, `$message` is logged at debug level and a bare `return`
/// is executed.
///
/// NOTE: the `return` leaves the enclosing function or closure, not just the current
/// loop iteration.
macro_rules! ignore_or_report {
( $x:expr, $message:expr ) => {
match $x {
Ok(_) => {}
Err(_) => {
debug!($message);
return;
}
};
};
}
/// Common trait for DNS servers
pub trait DnsServer {
/// Initialize the server and start listening
///
/// This method should _NOT_ block. Rather, servers are expected to spawn a new
/// thread to handle requests and return immediately.
///
/// The server is taken by value, so a given instance can only be started once.
fn run_server(self) -> Result<()>;
}
/// Follows the hosts referenced by CNAME and SRV records in `lookup_list`.
///
/// Each referenced host is looked up as `A` and then as `AAAA`. Every successful
/// lookup is appended to `results`, and any CNAMEs that lookup left unresolved are
/// followed recursively. This usually spares the client from having to perform
/// additional lookups itself. Failed lookups are silently skipped.
///
/// `depth` is the current recursion level; nothing is resolved once it exceeds 10.
fn resolve_cnames(lookup_list: &[DnsRecord], results: &mut Vec<DnsPacket>, resolver: &mut Box<dyn DnsResolver>, depth: u16) {
    if depth > 10 {
        return;
    }

    for record in lookup_list {
        // Only CNAME and SRV records name another host worth chasing
        let host = match record {
            DnsRecord::CNAME { host, .. } | DnsRecord::SRV { host, .. } => host,
            _ => continue
        };

        for &qtype in &[QueryType::A, QueryType::AAAA] {
            if let Ok(found) = resolver.resolve(host, qtype, true) {
                let pending = found.get_unresolved_cnames();
                results.push(found);
                resolve_cnames(&pending, results, resolver, depth + 1);
            }
        }
    }
}
/// Answer a single DNS request
///
/// The response always echoes the request id and recursion flag. The request is
/// then checked in this order:
///
/// * recursion requested but not allowed by the context -> `REFUSED`
/// * no question in the request -> `FORMERR`
/// * otherwise the first question is handed to a resolver created by the context;
///   a resolver error is logged and reported as `SERVFAIL`
///
/// On a successful resolve, unresolved CNAME/SRV hosts are followed as well and the
/// answers, authorities and resources of every lookup are merged into the response.
///
/// A packet is returned in every case, since the client should always get a reply.
pub fn execute_query(context: Arc<ServerContext>, request: &DnsPacket) -> DnsPacket {
    let mut response = DnsPacket::new();
    response.header.id = request.header.id;
    response.header.recursion_available = context.allow_recursive;
    response.header.recursion_desired = request.header.recursion_desired;
    response.header.response = true;

    if request.header.recursion_desired && !context.allow_recursive {
        response.header.rescode = ResultCode::REFUSED;
        return response;
    }

    let question = match request.questions.first() {
        Some(question) => question,
        None => {
            response.header.rescode = ResultCode::FORMERR;
            return response;
        }
    };
    response.questions.push(question.clone());

    let mut resolver = context.create_resolver(Arc::clone(&context));
    let mut lookups = Vec::new();

    let outcome = resolver.resolve(&question.name, question.qtype, request.header.recursion_desired);
    let rescode = match outcome {
        Ok(primary) => {
            if primary.header.authoritative_answer {
                response.header.authoritative_answer = true;
            }
            let code = primary.header.rescode;
            let pending = primary.get_unresolved_cnames();
            lookups.push(primary);
            resolve_cnames(&pending, &mut lookups, &mut resolver, 0);
            code
        }
        Err(err) => {
            error!("Failed to resolve {:?} {}: {:?}", question.qtype, question.name, err);
            ResultCode::SERVFAIL
        }
    };
    response.header.rescode = rescode;

    // Merge the sections of every lookup, in the order the lookups were made
    for lookup in lookups {
        response.answers.extend(lookup.answers);
        response.authorities.extend(lookup.authorities);
        response.resources.extend(lookup.resources);
    }

    response
}
/// The UDP server
///
/// Accepts DNS queries through UDP, and uses the `ServerContext` to determine
/// how to service the request. Packets are read and parsed on a single thread and
/// pushed onto a shared queue, from which a fixed pool of `thread_count` worker
/// threads picks them up and sends the responses.
pub struct DnsUdpServer {
/// Shared server context (listen address, recursion policy, statistics, resolver factory)
context: Arc<ServerContext>,
/// Parsed requests waiting for a worker, each paired with the address of its sender
request_queue: Arc<Mutex<VecDeque<(SocketAddr, DnsPacket)>>>,
/// Notified once for every request pushed onto `request_queue`
request_cond: Arc<Condvar>,
/// Number of worker threads spawned by `run_server`
thread_count: usize
}
impl DnsUdpServer {
    /// Creates a UDP server with an empty request queue.
    ///
    /// No socket is bound and no thread is started here; that happens in `run_server`,
    /// which spawns `thread_count` worker threads.
    pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsUdpServer {
        let request_queue = Arc::new(Mutex::new(VecDeque::new()));
        let request_cond = Arc::new(Condvar::new());
        DnsUdpServer {
            context,
            request_queue,
            request_cond,
            thread_count
        }
    }
}
impl DnsServer for DnsUdpServer {
/// Launch the server
///
/// This method takes ownership of the server, preventing the method from being called multiple times.
fn run_server(self) -> Result<()> {
// Bind the socket
let socket = UdpSocket::bind(self.context.dns_listen.as_str())
.expect(&format!("Cannot start DNS server on {}! Change listen address in config!", self.context.dns_listen.as_str()));
// Spawn threads for handling requests
for thread_id in 0..self.thread_count {
let socket_clone = match socket.try_clone() {
Ok(x) => x,
Err(e) => {
warn!("Failed to clone socket when starting UDP server: {:?}", e);
continue;
}
};
let context = Arc::clone(&self.context);
let request_cond = self.request_cond.clone();
let request_queue = self.request_queue.clone();
let name = "DnsUdpServer-request-".to_string() + &thread_id.to_string();
let _ = Builder::new().name(name).spawn(move || {
loop {
// Acquire lock, and wait on the condition until data is
// available. Then proceed with popping an entry of the queue.
let (src, request) = match request_queue
.lock()
.ok()
.and_then(|x| request_cond.wait(x).ok())
.and_then(|mut x| x.pop_front())
{
Some(x) => x,
None => {
debug!("Not expected to happen!");
continue;
}
};
let mut size_limit = 512;
// Check for EDNS
if request.resources.len() == 1 {
if let DnsRecord::OPT { packet_len, .. } = request.resources[0] {
size_limit = packet_len as usize;
}
}
// Create a response buffer, and ask the context for an appropriate resolver
let mut res_buffer = VectorPacketBuffer::new();
let mut packet = execute_query(Arc::clone(&context), &request);
let _ = packet.write(&mut res_buffer, size_limit);
// Fire off the response
let len = res_buffer.pos();
let data = return_or_report!(res_buffer.get_range(0, len), "Failed to get buffer data");
ignore_or_report!(socket_clone.send_to(data, src), "Failed to send response packet");
}
})?;
}
// Start servicing requests
let _ = Builder::new()
.name("DnsUdpServer-incoming".into())
.spawn(move || {
loop {
let _ = self.context.statistics.udp_query_count.fetch_add(1, Ordering::Release);
// Read a query packet
let mut req_buffer = BytePacketBuffer::new();
let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
Ok(x) => x,
Err(err) => {
if let Some(code) = err.raw_os_error() {
if code == 10004 {
debug!("UDP service loop has finished");
break;
}
}
debug!("Failed to read from UDP socket: {:?}", err);
continue;
}
};
// Parse it
let request = match DnsPacket::from_buffer(&mut req_buffer) {
Ok(x) => x,
Err(e) => {
debug!("Failed to parse UDP query packet: {:?}", e);
continue;
}
};
// Acquire lock, add request to queue, and notify waiting threads using the condition.
match self.request_queue.lock() {
Ok(mut queue) => {
queue.push_back((src, request));
self.request_cond.notify_one();
}
Err(e) => {
debug!("Failed to send UDP request for processing: {}", e);
}
}
}
})?;
Ok(())
}
}
/// TCP DNS server
///
/// Accepts connections on a single thread and hands each accepted stream to one of
/// `thread_count` worker threads through a per-worker channel.
pub struct DnsTcpServer {
/// Shared server context (listen address, recursion policy, statistics, resolver factory)
context: Arc<ServerContext>,
/// Sending halves of the worker channels; filled by `run_server`, one per worker
senders: Vec<Sender<TcpStream>>,
/// Number of worker threads spawned by `run_server`
thread_count: usize
}
impl DnsTcpServer {
    /// Creates a TCP server that has no worker channels yet.
    ///
    /// No listener is bound and no thread is started here; `run_server` does both and
    /// spawns `thread_count` worker threads.
    pub fn new(context: Arc<ServerContext>, thread_count: usize) -> DnsTcpServer {
        let senders = Vec::new();
        DnsTcpServer {
            context,
            senders,
            thread_count
        }
    }
}
impl DnsServer for DnsTcpServer {
fn run_server(mut self) -> Result<()> {
let socket = TcpListener::bind(self.context.dns_listen.as_str())
.expect(&format!("Cannot start DNS server on {}! Change listen address in config!", self.context.dns_listen.as_str()));
// Spawn threads for handling requests, and create the channels
for thread_id in 0..self.thread_count {
let (tx, rx) = channel();
self.senders.push(tx);
let context = Arc::clone(&self.context);
let name = "DnsTcpServer-request-".to_string() + &thread_id.to_string();
let _ = Builder::new().name(name).spawn(move || {
loop {
let mut stream = match rx.recv() {
Ok(x) => x,
Err(_) => continue
};
let _ = context.statistics.tcp_query_count.fetch_add(1, Ordering::Release);
// When DNS packets are sent over TCP, they're prefixed with a two byte
// length. We don't really need to know the length in advance, so we
// just move past it and continue reading as usual
ignore_or_report!(read_packet_length(&mut stream), "Failed to read query packet length");
let request = {
let mut stream_buffer = StreamPacketBuffer::new(&mut stream);
return_or_report!(DnsPacket::from_buffer(&mut stream_buffer), "Failed to read query packet")
};
let mut res_buffer = VectorPacketBuffer::new();
let mut packet = execute_query(Arc::clone(&context), &request);
ignore_or_report!(packet.write(&mut res_buffer, 0xFFFF), "Failed to write packet to buffer");
// As is the case for incoming queries, we need to send a 2 byte length
// value before handing of the actual packet.
let len = res_buffer.pos();
ignore_or_report!(write_packet_length(&mut stream, len), "Failed to write packet size");
// Now we can go ahead and write the actual packet
let data = return_or_report!(res_buffer.get_range(0, len), "Failed to get packet data");
ignore_or_report!(stream.write(data), "Failed to write response packet");
ignore_or_report!(stream.shutdown(Shutdown::Both), "Failed to shutdown socket");
}
})?;
}
let _ = Builder::new()
.name("DnsTcpServer-incoming".into())
.spawn(move || {
for wrap_stream in socket.incoming() {
let stream = match wrap_stream {
Ok(stream) => stream,
Err(err) => {
if let Some(code) = err.raw_os_error() {
if code == 10004 {
debug!("TCP service loop has finished");
break;
}
}
warn!("Failed to accept TCP connection: {:?}", err);
continue;
}
};
// Hand it off to a worker thread
let thread_no = random::<usize>() % self.thread_count;
match self.senders[thread_no].send(stream) {
Ok(_) => {}
Err(e) => {
warn!("Failed to send TCP request for processing on thread {}: {}", thread_no, e);
}
}
}
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;
    use std::sync::Arc;

    use super::*;
    use crate::dns::context::tests::create_test_context;
    use crate::dns::context::ResolveStrategy;
    use crate::dns::protocol::{DnsPacket, DnsQuestion, DnsRecord, QueryType, ResultCode, TransientTtl};

    /// Builds a query with a single question and recursion requested.
    fn build_query(qname: &str, qtype: QueryType) -> DnsPacket {
        let mut query_packet = DnsPacket::new();
        query_packet.header.recursion_desired = true;
        query_packet.questions.push(DnsQuestion::new(qname.into(), qtype));
        query_packet
    }

    /// An A record pointing `domain` at 127.0.0.1 with a TTL of 3600.
    fn a_record(domain: &str) -> DnsRecord {
        DnsRecord::A {
            domain: domain.to_string(),
            addr: "127.0.0.1".parse::<Ipv4Addr>().unwrap(),
            ttl: TransientTtl(3600)
        }
    }

    /// A CNAME record aliasing `domain` to `host` with a TTL of 3600.
    fn cname_record(domain: &str, host: &str) -> DnsRecord {
        DnsRecord::CNAME {
            domain: domain.to_string(),
            host: host.to_string(),
            ttl: TransientTtl(3600)
        }
    }

    /// The forwarding strategy used by every context in these tests.
    fn forward_strategy() -> ResolveStrategy {
        ResolveStrategy::Forward { upstreams: vec![String::from("127.0.0.1:53")] }
    }

    /// Asserts that `record` is an A record for `expected`.
    fn assert_a_record(record: &DnsRecord, expected: &str) {
        match record {
            DnsRecord::A { domain, .. } => assert_eq!(expected, domain),
            _ => panic!("expected an A record for {}", expected)
        }
    }

    /// Asserts that `record` is a CNAME record for `expected`.
    fn assert_cname_record(record: &DnsRecord, expected: &str) {
        match record {
            DnsRecord::CNAME { domain, .. } => assert_eq!(expected, domain),
            _ => panic!("expected a CNAME record for {}", expected)
        }
    }

    #[test]
    fn test_execute_query() {
        // Construct a context to execute some queries successfully
        let mut context = create_test_context(Box::new(|qname, qtype, _, _| {
            let mut packet = DnsPacket::new();

            if qname == "google.com" {
                packet.answers.push(a_record("google.com"));
            } else if qname == "www.facebook.com" && qtype == QueryType::CNAME {
                // The alias and its target arrive in the same answer
                packet.answers.push(cname_record("www.facebook.com", "cdn.facebook.com"));
                packet.answers.push(a_record("cdn.facebook.com"));
            } else if qname == "www.microsoft.com" && qtype == QueryType::CNAME {
                // Only the alias is returned, the target needs a second lookup
                packet.answers.push(cname_record("www.microsoft.com", "cdn.microsoft.com"));
            } else if qname == "cdn.microsoft.com" && qtype == QueryType::A {
                packet.answers.push(a_record("cdn.microsoft.com"));
            } else {
                packet.header.rescode = ResultCode::NXDOMAIN;
            }

            Ok(packet)
        }));

        // The context has not been shared yet, so it can still be mutated in place
        Arc::get_mut(&mut context)
            .expect("test context must be uniquely owned")
            .resolve_strategy = forward_strategy();

        // A successful resolve
        {
            let res = execute_query(Arc::clone(&context), &build_query("google.com", QueryType::A));
            assert_eq!(1, res.answers.len());
            assert_a_record(&res.answers[0], "google.com");
        };

        // A successful resolve, that also resolves a CNAME without recursive lookup
        {
            let res = execute_query(Arc::clone(&context), &build_query("www.facebook.com", QueryType::CNAME));
            assert_eq!(2, res.answers.len());
            assert_cname_record(&res.answers[0], "www.facebook.com");
            assert_a_record(&res.answers[1], "cdn.facebook.com");
        };

        // A successful resolve, that also resolves a CNAME through recursive lookup
        {
            let res = execute_query(Arc::clone(&context), &build_query("www.microsoft.com", QueryType::CNAME));
            assert_eq!(2, res.answers.len());
            assert_cname_record(&res.answers[0], "www.microsoft.com");
            assert_a_record(&res.answers[1], "cdn.microsoft.com");
        };

        // An unsuccessful resolve, but without any error
        {
            let res = execute_query(Arc::clone(&context), &build_query("yahoo.com", QueryType::A));
            assert_eq!(ResultCode::NXDOMAIN, res.header.rescode);
            assert_eq!(0, res.answers.len());
        };

        // Disable recursive resolves to generate a failure
        Arc::get_mut(&mut context)
            .expect("test context must be uniquely owned")
            .allow_recursive = false;

        // This should generate an error code, since recursive resolves are
        // no longer allowed
        {
            let res = execute_query(Arc::clone(&context), &build_query("yahoo.com", QueryType::A));
            assert_eq!(ResultCode::REFUSED, res.header.rescode);
            assert_eq!(0, res.answers.len());
        };

        // Send a query without a question, which should fail with an error code
        {
            let query_packet = DnsPacket::new();
            let res = execute_query(Arc::clone(&context), &query_packet);
            assert_eq!(ResultCode::FORMERR, res.header.rescode);
            assert_eq!(0, res.answers.len());
        };

        // Now construct a context where the dns client will return a failure
        let mut context2 = create_test_context(Box::new(|_, _, _, _| {
            Err(crate::dns::client::ClientError::Io(std::io::Error::new(std::io::ErrorKind::NotFound, "Fail")))
        }));

        Arc::get_mut(&mut context2)
            .expect("test context must be uniquely owned")
            .resolve_strategy = forward_strategy();

        // We expect this to set the server failure rescode
        {
            let res = execute_query(context2.clone(), &build_query("yahoo.com", QueryType::A));
            assert_eq!(ResultCode::SERVFAIL, res.header.rescode);
            assert_eq!(0, res.answers.len());
        };
    }
}