You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Alfis/src/dns/cache.rs

383 lines
12 KiB
Rust

//! A thread-safe cache for DNS information.
extern crate serde;
use std::clone::Clone;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use chrono::*;
use derive_more::{Display, Error, From};
use serde::{Deserialize, Serialize};
use crate::dns::protocol::{DnsPacket, DnsRecord, QueryType, ResultCode};
/// Errors returned by the DNS cache API.
#[derive(Debug, Display, From, Error)]
pub enum CacheError {
/// Wraps an I/O error. NOTE(review): nothing in this file constructs this
/// variant; it may be used by callers elsewhere — confirm before removing.
Io(std::io::Error),
/// Acquiring the cache's `RwLock` failed because the lock is poisoned.
PoisonedLock
}
/// Result type used by the cache's fallible operations.
type Result<T> = std::result::Result<T, CacheError>;
/// Outcome of checking the cache for a (domain, query type) pair.
///
/// Derives the usual value-type traits so callers and tests can print,
/// compare and copy it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheState {
    /// At least one unexpired record of the requested type is cached.
    PositiveCache,
    /// An unexpired negative (NXDOMAIN) entry is cached for the type.
    NegativeCache,
    /// Nothing usable is cached: there is no entry, or it has expired.
    NotCached
}
/// A cached DNS record together with the moment it was stored.
///
/// Equality and hashing consider only `record`; `timestamp` is ignored, so a
/// `HashSet<RecordEntry>` keeps at most one entry per record value.
#[derive(Clone, Eq, Debug, Serialize, Deserialize)]
pub struct RecordEntry {
    pub record: DnsRecord,
    pub timestamp: DateTime<Local>
}

impl PartialEq for RecordEntry {
    fn eq(&self, other: &Self) -> bool {
        // Deliberately ignores `timestamp` (see the type docs).
        PartialEq::eq(&self.record, &other.record)
    }
}

impl Hash for RecordEntry {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Must hash exactly the fields compared by `eq`.
        Hash::hash(&self.record, state)
    }
}

/// Cached state for one query type of a domain.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RecordSet {
    /// A negative answer, valid for `ttl` seconds after `timestamp`.
    NoRecords { qtype: QueryType, ttl: u32, timestamp: DateTime<Local> },
    /// Positive answers; each entry expires according to its own record TTL.
    Records { qtype: QueryType, records: HashSet<RecordEntry> }
}
#[derive(Clone, Debug)]
pub struct DomainEntry {
pub domain: String,
pub record_types: HashMap<QueryType, RecordSet>,
pub hits: u32,
pub updates: u32
}
impl DomainEntry {
pub fn new(domain: String) -> DomainEntry {
DomainEntry { domain, record_types: HashMap::new(), hits: 0, updates: 0 }
}
pub fn store_nxdomain(&mut self, qtype: QueryType, ttl: u32) {
self.updates += 1;
let new_set = RecordSet::NoRecords { qtype, ttl, timestamp: Local::now() };
self.record_types.insert(qtype, new_set);
}
pub fn store_record(&mut self, rec: &DnsRecord) {
self.updates += 1;
let entry = RecordEntry { record: rec.clone(), timestamp: Local::now() };
if let Some(&mut RecordSet::Records { ref mut records, .. }) = self.record_types.get_mut(&rec.get_querytype()) {
if records.contains(&entry) {
records.remove(&entry);
}
records.insert(entry);
return;
}
let mut records = HashSet::new();
records.insert(entry);
let new_set = RecordSet::Records { qtype: rec.get_querytype(), records };
self.record_types.insert(rec.get_querytype(), new_set);
}
pub fn get_cache_state(&self, qtype: QueryType) -> CacheState {
match self.record_types.get(&qtype) {
Some(&RecordSet::Records { ref records, .. }) => {
let now = Local::now();
let mut valid_count = 0;
for entry in records {
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
let expires = entry.timestamp + ttl_offset;
if expires < now {
continue;
}
if entry.record.get_querytype() == qtype {
valid_count += 1;
}
}
if valid_count > 0 {
CacheState::PositiveCache
} else {
CacheState::NotCached
}
}
Some(&RecordSet::NoRecords { ttl, timestamp, .. }) => {
let now = Local::now();
let ttl_offset = Duration::seconds(ttl as i64);
let expires = timestamp + ttl_offset;
if expires < now {
CacheState::NotCached
} else {
CacheState::NegativeCache
}
}
None => CacheState::NotCached
}
}
pub fn fill_queryresult(&self, qtype: QueryType, result_vec: &mut Vec<DnsRecord>) {
let now = Local::now();
let current_set = match self.record_types.get(&qtype) {
Some(x) => x,
None => return
};
if let RecordSet::Records { ref records, .. } = *current_set {
for entry in records {
let ttl_offset = Duration::seconds(entry.record.get_ttl() as i64);
let expires = entry.timestamp + ttl_offset;
if expires < now {
continue;
}
if entry.record.get_querytype() == qtype {
result_vec.push(entry.record.clone());
}
}
}
}
}
#[derive(Default)]
pub struct Cache {
domain_entries: BTreeMap<String, Arc<DomainEntry>>
}
impl Cache {
pub fn new() -> Cache {
Cache { domain_entries: BTreeMap::new() }
}
fn get_cache_state(&mut self, qname: &str, qtype: QueryType) -> CacheState {
match self.domain_entries.get(qname) {
Some(x) => x.get_cache_state(qtype),
None => CacheState::NotCached
}
}
fn fill_queryresult(&mut self, qname: &str, qtype: QueryType, result_vec: &mut Vec<DnsRecord>, increment_stats: bool) {
if let Some(domain_entry) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
if increment_stats {
domain_entry.hits += 1
}
domain_entry.fill_queryresult(qtype, result_vec);
}
}
pub fn lookup(&mut self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
match self.get_cache_state(qname, qtype) {
CacheState::PositiveCache => {
let mut qr = DnsPacket::new();
self.fill_queryresult(qname, qtype, &mut qr.answers, true);
self.fill_queryresult(qname, QueryType::NS, &mut qr.authorities, false);
Some(qr)
}
CacheState::NegativeCache => {
let mut qr = DnsPacket::new();
qr.header.rescode = ResultCode::NXDOMAIN;
Some(qr)
}
CacheState::NotCached => None
}
}
pub fn store(&mut self, records: &[DnsRecord]) {
for rec in records {
let domain = match rec.get_domain() {
Some(x) => x,
None => continue
};
if let Some(ref mut rs) = self.domain_entries.get_mut(&domain).and_then(Arc::get_mut) {
rs.store_record(rec);
continue;
}
let mut rs = DomainEntry::new(domain.clone());
rs.store_record(rec);
self.domain_entries.insert(domain.clone(), Arc::new(rs));
}
}
pub fn store_nxdomain(&mut self, qname: &str, qtype: QueryType, ttl: u32) {
if let Some(ref mut rs) = self.domain_entries.get_mut(qname).and_then(Arc::get_mut) {
rs.store_nxdomain(qtype, ttl);
return;
}
let mut rs = DomainEntry::new(qname.to_string());
rs.store_nxdomain(qtype, ttl);
self.domain_entries.insert(qname.to_string(), Arc::new(rs));
}
}
/// Thread-safe front end that guards a `Cache` with an `RwLock`.
#[derive(Default)]
pub struct SynchronizedCache {
    pub cache: RwLock<Cache>
}

impl SynchronizedCache {
    /// Creates a wrapper around an empty cache.
    pub fn new() -> SynchronizedCache {
        SynchronizedCache { cache: RwLock::new(Cache::new()) }
    }

    /// Takes the write lock, reporting poisoning as `CacheError::PoisonedLock`.
    fn write_guard(&self) -> Result<std::sync::RwLockWriteGuard<'_, Cache>> {
        self.cache.write().map_err(|_| CacheError::PoisonedLock)
    }

    /// Returns shared handles to every cached domain entry, ordered by domain name.
    ///
    /// # Errors
    /// `CacheError::PoisonedLock` if the lock is poisoned.
    pub fn list(&self) -> Result<Vec<Arc<DomainEntry>>> {
        let guard = self.cache.read().map_err(|_| CacheError::PoisonedLock)?;
        Ok(guard.domain_entries.values().cloned().collect())
    }

    /// Looks up `qname`/`qtype`; a poisoned lock is treated as a cache miss.
    ///
    /// Uses the write lock because a hit updates the entry's statistics.
    pub fn lookup(&self, qname: &str, qtype: QueryType) -> Option<DnsPacket> {
        let mut guard = self.write_guard().ok()?;
        guard.lookup(qname, qtype)
    }

    /// Stores `records` into the cache.
    ///
    /// # Errors
    /// `CacheError::PoisonedLock` if the lock is poisoned.
    pub fn store(&self, records: &[DnsRecord]) -> Result<()> {
        self.write_guard()?.store(records);
        Ok(())
    }

    /// Stores a negative answer for `qname`/`qtype` valid for `ttl` seconds.
    ///
    /// # Errors
    /// `CacheError::PoisonedLock` if the lock is poisoned.
    pub fn store_nxdomain(&self, qname: &str, qtype: QueryType, ttl: u32) -> Result<()> {
        self.write_guard()?.store_nxdomain(qname, qtype, ttl);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::protocol::{DnsRecord, QueryType, ResultCode, TransientTtl};

    #[test]
    fn test_cache() {
        let mut cache = Cache::new();

        // Nothing is returned while the cache is empty.
        assert!(cache.lookup("www.google.com", QueryType::A).is_none());

        // A live negative entry must produce a response with NXDOMAIN set.
        cache.store_nxdomain("www.google.com", QueryType::A, 3600);
        let packet = cache
            .lookup("www.google.com", QueryType::A)
            .expect("negative cache entry should produce a response");
        assert_eq!(ResultCode::NXDOMAIN, packet.header.rescode);

        // A negative entry with no TTL is not served once it has expired.
        cache.store_nxdomain("www.yahoo.com", QueryType::A, 0);
        std::thread::sleep(core::time::Duration::from_secs(1));
        assert!(cache.lookup("www.yahoo.com", QueryType::A).is_none());

        // Now add some actual records.
        let records = vec![
            DnsRecord::A {
                domain: "www.google.com".to_string(),
                addr: "127.0.0.1".parse().unwrap(),
                ttl: TransientTtl(3600)
            },
            DnsRecord::A {
                domain: "www.yahoo.com".to_string(),
                addr: "127.0.0.2".parse().unwrap(),
                ttl: TransientTtl(0)
            },
            DnsRecord::CNAME {
                domain: "www.microsoft.com".to_string(),
                host: "www.somecdn.com".to_string(),
                ttl: TransientTtl(3600)
            },
        ];
        cache.store(&records);

        // Successful lookup of a positive entry.
        let packet = cache
            .lookup("www.google.com", QueryType::A)
            .expect("A record for www.google.com should be cached");
        assert_eq!(records[0], packet.answers[0]);

        // No CNAME is known for this domain.
        assert!(cache.lookup("www.google.com", QueryType::CNAME).is_none());

        // Successful CNAME lookup.
        let packet = cache
            .lookup("www.microsoft.com", QueryType::CNAME)
            .expect("CNAME for www.microsoft.com should be cached");
        assert_eq!(records[2], packet.answers[0]);

        // The record with a 0 second TTL must not be served from cache.
        assert!(cache.lookup("www.yahoo.com", QueryType::A).is_none());

        let records2 = vec![DnsRecord::A {
            domain: "www.yahoo.com".to_string(),
            addr: "127.0.0.2".parse().unwrap(),
            ttl: TransientTtl(3600)
        }];
        cache.store(&records2);

        // And now it should succeed, since the record has been stored.
        assert!(cache.lookup("www.yahoo.com", QueryType::A).is_some());

        // Check stat counter behavior.
        assert_eq!(3, cache.domain_entries.len());
        let google = &cache.domain_entries["www.google.com"];
        assert_eq!(1, google.hits);
        assert_eq!(2, google.updates);
        let yahoo = &cache.domain_entries["www.yahoo.com"];
        assert_eq!(1, yahoo.hits);
        assert_eq!(3, yahoo.updates);
        let microsoft = &cache.domain_entries["www.microsoft.com"];
        assert_eq!(1, microsoft.updates);
        assert_eq!(1, microsoft.hits);
    }
}