Retry truncated responses over TCP

pull/5/head
Frank Denis 5 years ago
parent 35819a2375
commit 0b3eabb488

@ -18,12 +18,12 @@ const DNS_TYPE_TXT: u16 = 16;
const DNS_CLASS_INET: u16 = 1;
#[inline]
fn qdcount(packet: &[u8]) -> u16 {
pub fn qdcount(packet: &[u8]) -> u16 {
BigEndian::read_u16(&packet[4..])
}
#[inline]
fn ancount(packet: &[u8]) -> u16 {
pub fn ancount(packet: &[u8]) -> u16 {
BigEndian::read_u16(&packet[6..])
}
@ -41,7 +41,7 @@ fn nscount(packet: &[u8]) -> u16 {
}
#[inline]
fn arcount(packet: &[u8]) -> u16 {
pub fn arcount(packet: &[u8]) -> u16 {
BigEndian::read_u16(&packet[10..])
}
@ -81,6 +81,16 @@ pub fn truncate(packet: &mut [u8]) {
);
}
#[inline]
pub fn is_response(packet: &[u8]) -> bool {
BigEndian::read_u16(&packet[DNS_OFFSET_FLAGS..]) & DNS_FLAGS_QR == DNS_FLAGS_QR
}
#[inline]
pub fn is_truncated(packet: &[u8]) -> bool {
BigEndian::read_u16(&packet[DNS_OFFSET_FLAGS..]) & DNS_FLAGS_TC == DNS_FLAGS_TC
}
pub fn qname(packet: &[u8]) -> Result<Vec<u8>, Error> {
assert!(std::usize::MAX > 0xffff);
ensure!(qdcount(packet) == 1, "Unexpected query count");

@ -66,6 +66,7 @@ enum ClientCtx {
}
async fn respond_to_query(client_ctx: ClientCtx, packet: Vec<u8>) -> Result<(), Error> {
ensure!(dns::is_response(&packet), "Packet is not a response");
match client_ctx {
ClientCtx::Udp(client_ctx) => {
let net_udp_socket = client_ctx.net_udp_socket;
@ -79,6 +80,7 @@ async fn respond_to_query(client_ctx: ClientCtx, packet: Vec<u8>) -> Result<(),
BigEndian::write_u16(&mut binlen[..], packet_len as u16);
client_connection.write_all(&binlen).await?;
client_connection.write_all(&packet).await?;
client_connection.flush();
}
}
Ok(())
@ -89,27 +91,57 @@ async fn handle_client_query(
client_ctx: ClientCtx,
mut packet: Vec<u8>,
) -> Result<(), Error> {
ensure!(packet.len() >= DNSCRYPT_QUERY_MIN_SIZE, "Short packet");
ensure!(dns::qdcount(&packet) == 1, "No question");
ensure!(
!dns::is_response(&packet),
"Question expected, but got a response instead"
);
if let Some(synth_packet) =
serve_certificates(&packet, &globals.provider_name, &globals.dnscrypt_certs)?
{
return respond_to_query(client_ctx, synth_packet).await;
}
set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?;
let original_tid = dns::tid(&packet);
let tid = random();
dns::set_tid(&mut packet, tid);
let mut ext_socket = UdpSocket::bind(&globals.external_addr).await?;
ext_socket.connect(&globals.upstream_addr).await?;
set_edns_max_payload_size(&mut packet, DNS_MAX_PACKET_SIZE as u16)?;
ext_socket.send(&packet).await.unwrap();
let mut response = vec![0u8; DNS_MAX_PACKET_SIZE];
let response_len = ext_socket.recv(&mut response[..]).await?;
ensure!(response_len > DNS_HEADER_SIZE, "Short packet");
response.truncate(response_len);
ensure!(dns::tid(&response) == tid, "Unexpected transaction ID");
ensure!(
dns::qname(&packet)? == dns::qname(&response)?,
"Unexpected query name in the response"
);
let mut response;
loop {
response = vec![0u8; DNS_MAX_PACKET_SIZE];
let response_len = ext_socket.recv(&mut response[..]).await?;
ensure!(response_len > DNS_HEADER_SIZE, "Short packet");
response.truncate(response_len);
if dns::tid(&response) == tid && dns::qname(&packet)? == dns::qname(&response)? {
break;
}
dbg!("Response collision");
}
if dns::is_truncated(&response) {
let mut ext_socket = TcpStream::connect(&globals.upstream_addr).await?;
ext_socket.set_nodelay(true)?;
let mut binlen = [0u8, 0];
BigEndian::write_u16(&mut binlen[..], packet.len() as u16);
ext_socket.write_all(&binlen).await?;
ext_socket.write_all(&packet).await?;
ext_socket.flush();
ext_socket.read_exact(&mut binlen).await?;
let response_len = BigEndian::read_u16(&binlen) as usize;
ensure!(
(DNS_HEADER_SIZE..=DNS_MAX_PACKET_SIZE).contains(&response_len),
"Unexpected response size"
);
response = vec![0u8; response_len];
ext_socket.read_exact(&mut response).await?;
ensure!(dns::tid(&response) == tid, "Unexpected transaction ID");
ensure!(
dns::qname(&packet)? == dns::qname(&response)?,
"Unexpected query name in the response"
);
}
dns::set_tid(&mut response, original_tid);
respond_to_query(client_ctx, response).await
}
@ -123,6 +155,7 @@ async fn tcp_acceptor(globals: Arc<Globals>, tcp_listener: TcpListener) -> Resul
Ok(client_connection) => client_connection,
Err(e) => bail!(e),
};
client_connection.set_nodelay(true)?;
let globals = globals.clone();
runtime.spawn(
async {

Loading…
Cancel
Save