// aosp12/packages/modules/DnsResolver/doh.rs
//
// 589 lines
// 20 KiB
// Rust

/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! DoH backend for the Android DnsResolver module.
use anyhow::{anyhow, Context, Result};
use lazy_static::lazy_static;
use libc::{c_char, size_t, ssize_t};
use log::{debug, error, info, warn};
use quiche::h3;
use ring::rand::SecureRandom;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::os::unix::io::{AsRawFd, RawFd};
use std::str::FromStr;
use std::sync::Arc;
use std::{ptr, slice};
use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use tokio::time::Duration;
use url::Url;
lazy_static! {
/// Tokio runtime used to perform doh-handler tasks.
///
/// Multi-threaded runtime with two worker threads named "doh-handler" and at
/// most one blocking thread. Handler tasks are spawned onto it and
/// `doh_query()` blocks on it while waiting for a response.
/// Initialization panics if the runtime cannot be built.
static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
Builder::new_multi_thread()
.worker_threads(2)
.max_blocking_threads(1)
.enable_all()
.thread_name("doh-handler")
.build()
.expect("Failed to create tokio runtime")
);
}
/// Capacity of the command channel between a `DohDispatcher` and its handler task.
const MAX_BUFFERED_CMD_SIZE: usize = 400;
/// Value passed to `set_initial_max_data` (connection-wide incoming data limit).
const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
/// Value passed to the `set_initial_max_stream_data_*` setters (per-stream incoming data limit).
const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
/// Value passed to `set_initial_max_streams_bidi` and `set_initial_max_streams_uni`.
const MAX_CONCURRENT_STREAM_SIZE: u64 = 100;
/// Size in bytes of the buffers used for outgoing QUIC packets and for reading response bodies.
const MAX_DATAGRAM_SIZE: usize = 1350;
/// Same value as `MAX_DATAGRAM_SIZE`, typed as `u64` for `set_max_udp_payload_size`.
const MAX_DATAGRAM_SIZE_U64: u64 = 1350;
/// UDP port the DoH socket connects to.
const DOH_PORT: u16 = 443;
/// QUIC max idle timeout in milliseconds; also the wait used when the connection reports no
/// pending timeout.
const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;
/// Directory of trusted CA certificates used when the caller supplies no certificate path.
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
/// Connection ID bytes handed to `quiche::connect`.
type SCID = [u8; quiche::MAX_CONN_ID_LEN];
/// Query bytes as submitted by the caller. They are appended verbatim after `?dns=` in the
/// request path, so they must be valid UTF-8 (the tests use a base64-encoded DNS message).
type Query = Vec<u8>;
/// Raw response body bytes read from the HTTP/3 stream.
type Response = Vec<u8>;
/// Sending half of the command channel, held by `DohDispatcher`.
type CmdSender = mpsc::Sender<Command>;
/// Receiving half of the command channel, owned by the handler task.
type CmdReceiver = mpsc::Receiver<Command>;
/// One-shot channel through which the handler task hands a query's outcome back to the caller.
type QueryResponder = oneshot::Sender<Option<Response>>;
/// Commands accepted by the DoH handler task over its command channel.
#[derive(Debug)]
enum Command {
/// Sends `query` to the server; the outcome is delivered through `resp`.
DohQuery { query: Query, resp: QueryResponder },
}
/// Context for a running DoH engine.
pub struct DohDispatcher {
/// Used to submit queries to the I/O thread.
query_sender: CmdSender,
/// Handle of the spawned `doh_handler` task; `abort_handler()` uses it to cancel the task.
join_handle: task::JoinHandle<Result<()>>,
}
/// Creates a non-blocking UDP socket connected to `ip_addr` on the DoH port.
///
/// The socket is bound to the wildcard address of the same address family as
/// `ip_addr` and tagged with `mark` before it is connected.
///
/// # Errors
/// Fails if `ip_addr` is not a valid IP literal or if any socket operation fails.
fn make_doh_udp_socket(ip_addr: &str, mark: u32) -> Result<std::net::UdpSocket> {
    let server = SocketAddr::new(IpAddr::from_str(ip_addr)?, DOH_PORT);
    // Bind to an ephemeral port on the wildcard address matching the server's family.
    let wildcard = if server.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let socket = std::net::UdpSocket::bind(wildcard)?;
    socket.set_nonblocking(true)?;
    // The mark is applied before connect().
    mark_socket(socket.as_raw_fd(), mark)?;
    socket.connect(server)?;
    debug!("connecting to {:} from {:}", server, socket.local_addr()?);
    Ok(socket)
}
// DoH dispatcher
impl DohDispatcher {
fn new(
url: &str,
ip_addr: &str,
mark: u32,
cert_path: Option<&str>,
) -> Result<Box<DohDispatcher>> {
// Setup socket
let udp_sk = make_doh_udp_socket(&ip_addr, mark)?;
DohDispatcher::new_with_socket(url, ip_addr, mark, cert_path, udp_sk)
}
fn new_with_socket(
url: &str,
ip_addr: &str,
mark: u32,
cert_path: Option<&str>,
udp_sk: std::net::UdpSocket,
) -> Result<Box<DohDispatcher>> {
let url = Url::parse(&url.to_string())?;
if url.domain().is_none() {
return Err(anyhow!("no domain"));
}
// Setup quiche config
let config = create_quiche_config(cert_path)?;
let h3_config = h3::Config::new()?;
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
ring::rand::SystemRandom::new().fill(&mut scid[..]).context("failed to generate scid")?;
let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_SIZE);
debug!(
"Creating a doh handler task: url={}, ip_addr={}, mark={:#x}, scid {:x?}",
url, ip_addr, mark, &scid
);
let join_handle =
RUNTIME_STATIC.spawn(doh_handler(url, udp_sk, config, h3_config, scid, cmd_receiver));
Ok(Box::new(DohDispatcher { query_sender: cmd_sender, join_handle }))
}
fn query(&self, cmd: Command) -> Result<()> {
self.query_sender.blocking_send(cmd)?;
Ok(())
}
fn abort_handler(&self) {
self.join_handle.abort();
}
}
/// Drives a single DoH-over-HTTP/3 connection until the task is aborted or a fatal error occurs.
///
/// Each loop iteration waits for one of: a UDP datagram from the server, a new
/// command from `rx`, or the current QUIC timeout. It then resumes a closed
/// connection if queries are pending, creates the HTTP/3 connection once QUIC is
/// established, delivers received answers, submits pending queries and finally
/// flushes outgoing QUIC packets to the socket.
///
/// # Errors
/// Returns an error if the socket cannot be converted into a tokio socket or if
/// creating the QUIC or HTTP/3 connection fails.
async fn doh_handler(
    url: url::Url,
    udp_sk: std::net::UdpSocket,
    mut config: quiche::Config,
    h3_config: h3::Config,
    scid: SCID,
    mut rx: CmdReceiver,
) -> Result<()> {
    debug!("doh_handler: url={:?}", url);
    let sk = UdpSocket::from_std(udp_sk)?;
    let mut conn = quiche::connect(url.domain(), &scid, &mut config)?;
    let mut quic_conn_start = std::time::Instant::now();
    let mut h3_conn: Option<h3::Connection> = None;
    let mut is_idle = false;
    let mut buf = [0; 65535];
    // Responders of in-flight queries, keyed by HTTP/3 stream id.
    let mut query_map = HashMap::<u64, QueryResponder>::new();
    // Queries accepted from callers but not yet sent on a stream.
    let mut pending_cmds: Vec<Command> = Vec::new();
    let mut ts = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS);
    loop {
        tokio::select! {
            size = sk.recv(&mut buf) => {
                debug!("recv {:?} ", size);
                match size {
                    Ok(size) => {
                        let processed = match conn.recv(&mut buf[..size]) {
                            Ok(l) => l,
                            Err(e) => {
                                error!("quic recv failed: {:?}", e);
                                continue;
                            }
                        };
                        debug!("processed {} bytes", processed);
                    },
                    Err(e) => {
                        error!("socket recv failed: {:?}", e);
                        continue;
                    },
                };
            }
            Some(cmd) = rx.recv() => {
                debug!("recv {:?}", cmd);
                pending_cmds.push(cmd);
            }
            _ = tokio::time::sleep(ts) => {
                conn.on_timeout();
                debug!("quic connection timeout");
            }
        }
        if conn.is_closed() {
            // Show connection statistics after it's closed
            if !is_idle {
                info!("connection closed, {:?}, {:?}", quic_conn_start.elapsed(), conn.stats());
                is_idle = true;
                if !conn.is_established() {
                    error!("connection handshake timed out after {:?}", quic_conn_start.elapsed());
                }
                // No answer can arrive on a closed connection. Dropping the outstanding
                // responders lets blocked callers observe an error instead of waiting
                // forever, and keeps stale entries from colliding with the stream ids
                // of the next connection.
                query_map.clear();
            }
            // If there is any pending query, resume the quic connection.
            if !pending_cmds.is_empty() {
                info!("still some pending queries but connection is not avaiable, resume it");
                conn = quiche::connect(url.domain(), &scid, &mut config)?;
                quic_conn_start = std::time::Instant::now();
                h3_conn = None;
                is_idle = false;
            }
        }
        // Create a new HTTP/3 connection once the QUIC connection is established.
        if conn.is_established() && h3_conn.is_none() {
            info!("quic ready, creating h3 conn");
            h3_conn = Some(quiche::h3::Connection::with_transport(&mut conn, &h3_config)?);
        }
        // Try to receive query answers from h3 connection.
        if let Some(h3) = h3_conn.as_mut() {
            recv_query(h3, &mut conn, &mut query_map).await;
        }
        // Update the next timeout of quic connection.
        ts = conn.timeout().unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
        info!("next connection timouts {:?}", ts);
        // Process the pending queries
        while !pending_cmds.is_empty() && conn.is_established() {
            if let Some(cmd) = pending_cmds.pop() {
                match cmd {
                    Command::DohQuery { query, resp } => {
                        match send_dns_query(&query, &url, &mut h3_conn, &mut conn) {
                            Ok(stream_id) => {
                                query_map.insert(stream_id, resp);
                            }
                            Err(e) => {
                                info!("failed to send query {}", e);
                                pending_cmds.push(Command::DohQuery { query, resp });
                                // Retrying right away would pop the same command and fail
                                // again without ever returning to the select loop. Keep the
                                // query queued and try again after the next event.
                                break;
                            }
                        }
                    }
                }
            }
        }
        flush_tx(&sk, &mut conn).await.unwrap_or_else(|e| {
            error!("flush error {:?} ", e);
        });
    }
}
/// Sends `query` as an HTTP/3 GET request and returns the id of the stream it was sent on.
///
/// The request path is the path of `url` followed by `?dns=` and the query bytes,
/// which are used verbatim and therefore must be valid UTF-8.
///
/// # Errors
/// Fails if `h3_conn` is `None`, if `query` is not valid UTF-8, if `url` has no
/// host, or if quiche refuses the request.
fn send_dns_query(
    query: &[u8],
    url: &url::Url,
    h3_conn: &mut Option<quiche::h3::Connection>,
    conn: &mut quiche::Connection,
) -> Result<u64> {
    let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
    let mut path = String::from(url.path());
    path.push_str("?dns=");
    path.push_str(std::str::from_utf8(query)?);
    let req = vec![
        quiche::h3::Header::new(":method", "GET"),
        quiche::h3::Header::new(":scheme", "https"),
        quiche::h3::Header::new(
            ":authority",
            url.host_str().ok_or_else(|| anyhow!("failed to get host"))?,
        ),
        quiche::h3::Header::new(":path", &path),
        quiche::h3::Header::new("user-agent", "quiche"),
        quiche::h3::Header::new("accept", "application/dns-message"),
        // TODO: is content-length required?
    ];
    // A GET request carries no body, so the request stream is finished together with
    // the headers. Leaving it open would announce a body that is never sent and keep
    // the stream half-open.
    Ok(h3_conn.send_request(conn, &req, true /*fin*/)?)
}
/// Drains pending HTTP/3 events and hands response bodies to the waiting callers.
///
/// For a `Data` event, up to `MAX_DATAGRAM_SIZE` bytes of body are read and sent to
/// the responder registered for that stream in `map`; the responder is removed.
/// For a `Finished` event, a responder that is still registered (the stream ended
/// without a readable body) receives `None` so its caller does not wait forever.
/// Processing stops when quiche reports `Done` or any other HTTP/3 error.
///
/// NOTE: only the first chunk of a body is delivered; a response larger than
/// `MAX_DATAGRAM_SIZE` or split across several `Data` events is truncated.
async fn recv_query(
    h3_conn: &mut h3::Connection,
    conn: &mut quiche::Connection,
    map: &mut HashMap<u64, QueryResponder>,
) {
    // Process HTTP/3 events.
    let mut buf = [0; MAX_DATAGRAM_SIZE];
    loop {
        match h3_conn.poll(conn) {
            Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
                info!(
                    "got response headers {:?} on stream id {} has_body {}",
                    list, stream_id, has_body
                );
            }
            Ok((stream_id, quiche::h3::Event::Data)) => {
                debug!("quiche::h3::Event::Data");
                match h3_conn.recv_body(conn, stream_id, &mut buf) {
                    Ok(read) => {
                        info!(
                            "got {} bytes of response data on stream {}: {:x?}",
                            read,
                            stream_id,
                            &buf[..read]
                        );
                        if let Some(resp) = map.remove(&stream_id) {
                            resp.send(Some(buf[..read].to_vec())).unwrap_or_else(|e| {
                                warn!("the receiver dropped {:?}", e);
                            });
                        }
                    }
                    Err(e) => {
                        // The responder stays registered; it is answered by a later
                        // Data or Finished event on the same stream.
                        debug!("recv_body failed on stream {}: {:?}", stream_id, e);
                    }
                }
            }
            Ok((stream_id, quiche::h3::Event::Finished)) => {
                debug!("quiche::h3::Event::Finished");
                // A responder still registered here never received a body. Report
                // "no answer" instead of leaving its caller blocked.
                if let Some(resp) = map.remove(&stream_id) {
                    resp.send(None).unwrap_or_else(|e| {
                        warn!("the receiver dropped {:?}", e);
                    });
                }
            }
            Ok((_stream_id, quiche::h3::Event::Datagram)) => {
                debug!("quiche::h3::Event::Datagram");
            }
            Ok((_stream_id, quiche::h3::Event::GoAway)) => {
                debug!("quiche::h3::Event::GoAway");
            }
            Err(quiche::h3::Error::Done) => {
                debug!("quiche::h3::Error::Done");
                break;
            }
            Err(e) => {
                error!("HTTP/3 processing failed: {:?}", e);
                break;
            }
        }
    }
}
/// Writes every QUIC packet the connection currently wants to send to the socket.
///
/// # Errors
/// Fails if a socket send fails. If quiche fails to produce a packet, the
/// connection is closed with error code `0x1` before the error is returned.
async fn flush_tx(sk: &UdpSocket, conn: &mut quiche::Connection) -> Result<()> {
    let mut packet = [0; MAX_DATAGRAM_SIZE];
    loop {
        match conn.send(&mut packet) {
            Ok(len) => {
                sk.send(&packet[..len]).await?;
                debug!("written {}", len);
            }
            // Nothing left to send.
            Err(quiche::Error::Done) => {
                debug!("done writing");
                return Ok(());
            }
            Err(e) => {
                conn.close(false, 0x1, b"fail").ok();
                return Err(anyhow::Error::new(e));
            }
        }
    }
}
/// Builds the quiche configuration used for DoH connections.
///
/// Peer verification is enabled and trusted certificates are loaded from
/// `cert_path`, or from `SYSTEM_CERT_PATH` when `cert_path` is `None`.
///
/// # Errors
/// Fails if the configuration cannot be created, the HTTP/3 application protocol
/// cannot be set, or the certificate directory cannot be loaded.
fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> {
    let ca_dir = cert_path.unwrap_or(SYSTEM_CERT_PATH);

    let mut cfg = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
    cfg.set_application_protos(h3::APPLICATION_PROTOCOL)?;

    // Certificate verification.
    cfg.verify_peer(true);
    cfg.load_verify_locations_from_directory(ca_dir)?;

    // Some of these configs are necessary, or the server can't respond the HTTP/3 request.
    cfg.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
    cfg.set_max_udp_payload_size(MAX_DATAGRAM_SIZE_U64);

    // Incoming data limits.
    cfg.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE);
    cfg.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH);
    cfg.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH);
    cfg.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH);

    // Stream count limits.
    cfg.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE);
    cfg.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE);

    cfg.set_disable_active_migration(true);
    Ok(cfg)
}
fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
// libc::setsockopt is a wrapper function calling into bionic setsockopt.
// Both fd and mark are valid, which makes the function call mostly safe.
if unsafe {
libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_MARK,
&mark as *const _ as *const libc::c_void,
std::mem::size_of::<u32>() as libc::socklen_t,
)
} == 0
{
Ok(())
} else {
Err(anyhow::Error::new(std::io::Error::last_os_error()))
}
}
/// Performs static initialization for the DoH engine.
///
/// Initializes the Android logger with a minimum level of `Trace` and returns a
/// pointer to a static, NUL-terminated version string.
#[no_mangle]
pub extern "C" fn doh_init() -> *const c_char {
    // The trailing NUL lets C callers treat the returned pointer as a C string.
    static VERSION: &str = "1.0\0";
    let logger_config = android_logger::Config::default().with_min_level(log::Level::Trace);
    android_logger::init_once(logger_config);
    VERSION.as_ptr().cast::<c_char>()
}
/// Creates and returns a DoH engine instance.
/// The returned object must be freed with doh_delete().
/// # Safety
/// All the pointer args are null terminated strings.
#[no_mangle]
pub unsafe extern "C" fn doh_new(
url: *const c_char,
ip_addr: *const c_char,
mark: libc::uint32_t,
cert_path: *const c_char,
) -> *mut DohDispatcher {
let (url, ip_addr, cert_path) = match (
std::ffi::CStr::from_ptr(url).to_str(),
std::ffi::CStr::from_ptr(ip_addr).to_str(),
std::ffi::CStr::from_ptr(cert_path).to_str(),
) {
(Ok(url), Ok(ip_addr), Ok(cert_path)) => {
if !cert_path.is_empty() {
(url, ip_addr, Some(cert_path))
} else {
(url, ip_addr, None)
}
}
_ => {
error!("bad input");
return ptr::null_mut();
}
};
match DohDispatcher::new(url, ip_addr, mark, cert_path) {
Ok(c) => Box::into_raw(c),
Err(e) => {
error!("doh_new: failed: {:?}", e);
ptr::null_mut()
}
}
}
/// Deletes a DoH engine created by doh_new().
///
/// Aborts the engine's handler task and frees the dispatcher. A null pointer
/// (which doh_new() returns on failure) is accepted and ignored.
/// # Safety
/// `doh` must be null or a pointer previously created by `doh_new()`
/// and not yet deleted by `doh_delete()`.
#[no_mangle]
pub unsafe extern "C" fn doh_delete(doh: *mut DohDispatcher) {
    if doh.is_null() {
        return;
    }
    // SAFETY: `doh` is non-null and, per the contract above, came from
    // `Box::into_raw` in doh_new() and has not been freed yet.
    Box::from_raw(doh).abort_handler()
}
/// Sends a DNS query and waits for the response.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_new()`
/// and not yet deleted by `doh_delete()`.
/// `query` must point to a buffer at least `query_len` in size.
/// `response` must point to a buffer at least `response_len` in size.
#[no_mangle]
pub unsafe extern "C" fn doh_query(
doh: &mut DohDispatcher,
query: *mut u8,
query_len: size_t,
response: *mut u8,
response_len: size_t,
) -> ssize_t {
let q = slice::from_raw_parts_mut(query, query_len);
let (resp_tx, resp_rx) = oneshot::channel();
let cmd = Command::DohQuery { query: q.to_vec(), resp: resp_tx };
if let Err(e) = doh.query(cmd) {
error!("Failed to send the query: {:?}", e);
return -1;
}
match RUNTIME_STATIC.block_on(resp_rx) {
Ok(value) => {
if let Some(resp) = value {
if resp.len() > response_len || resp.len() > isize::MAX as usize {
return -1;
}
let response = slice::from_raw_parts_mut(response, resp.len());
response.copy_from_slice(&resp);
return resp.len() as ssize_t;
}
-1
}
Err(e) => {
error!("no result {}", e);
-1
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    const TEST_MARK: u32 = 0xD0033;
    const LOOPBACK_ADDR: &str = "127.0.0.1";
    const GOOGLE_DNS_URL: &str = "https://dns.google/dns-query";
    const GOOGLE_DNS_IP: &str = "8.8.8.8";
    // qtype: A, qname: www.example.com
    const SAMPLE_QUERY: &str = "q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB";

    /// Creating a dispatcher must fail for malformed URLs and IP addresses.
    #[test]
    fn dohdispatcher_invalid_args() {
        let bad_args = [
            // Bad url
            ("foo", "bar"),
            ("https://1", "bar"),
            ("https:/", "bar"),
            // Bad ip
            ("https://dns.google", "bar"),
            ("https://dns.google", "256.256.256.256"),
        ];
        for &(url, ip_addr) in &bad_args {
            let created = DohDispatcher::new(url, ip_addr, 0, None);
            assert!(created.is_err(), "doh dispatcher should not be created");
        }
    }

    /// The socket helper rejects bad addresses and yields a connected, marked,
    /// non-blocking socket for a good one.
    #[test]
    fn make_doh_udp_socket() {
        // Bad ip
        let bad_ips = ["foo", "1", "333.333.333.333"];
        for ip in bad_ips.iter() {
            assert!(super::make_doh_udp_socket(ip, 0).is_err(), "udp socket should not be created");
        }

        // Make a socket connecting to loopback with a test mark.
        let sk = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();

        // Check if the socket is connected to loopback.
        let expected_peer =
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT));
        assert_eq!(sk.peer_addr().unwrap(), expected_peer);

        // Check if the socket mark is correct.
        let fd: RawFd = sk.as_raw_fd();
        let mut mark: u32 = 50;
        let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
        // Safety: fd must be valid.
        let rc = unsafe {
            libc::getsockopt(
                fd,
                libc::SOL_SOCKET,
                libc::SO_MARK,
                &mut mark as *mut _ as *mut libc::c_void,
                &mut size,
            )
        };
        assert_eq!(rc, 0);
        assert_eq!(mark, TEST_MARK);

        // Check if the socket is non-blocking.
        // Safety: fd must be valid.
        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL, 0) };
        assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK);
    }

    /// The quiche configuration can be built with and without an explicit cert path.
    #[test]
    fn create_quiche_config() {
        let default_certs = super::create_quiche_config(None);
        assert!(default_certs.is_ok(), "quiche config without cert creating failed");

        let custom_certs = super::create_quiche_config(Some("data/local/tmp/"));
        assert!(custom_certs.is_ok(), "quiche config with cert creating failed");
    }

    /// Aborting the handler drops the responder of a query that is still queued.
    #[test]
    fn close_doh() {
        let socket = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();
        let doh =
            DohDispatcher::new_with_socket(GOOGLE_DNS_URL, GOOGLE_DNS_IP, 0, None, socket).unwrap();

        let (resp_tx, resp_rx) = oneshot::channel();
        let query = SAMPLE_QUERY.as_bytes().to_vec();
        assert!(doh.query(Command::DohQuery { query, resp: resp_tx }).is_ok(), "Send query failed");

        doh.abort_handler();
        let outcome = RUNTIME_STATIC.block_on(resp_rx);
        assert!(outcome.is_err(), "channel should already be closed");
    }

    /// `doh_init()` returns the version as a C string.
    #[test]
    fn doh_init() {
        // Safety: the returned pointer of doh_init() must be a null terminated string.
        let version = unsafe { std::ffi::CStr::from_ptr(super::doh_init()) };
        assert_eq!(version.to_str().unwrap(), "1.0");
    }
}