aosp12/packages/modules/DnsResolver/DnsTlsQueryMap.cpp

159 lines
4.4 KiB
C++
Raw Normal View History

2023-01-09 17:11:35 +08:00
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "resolv"
#include "DnsTlsQueryMap.h"
#include <android-base/logging.h>
#include "Experiments.h"
namespace android {
namespace net {
// Initializes the per-query retry limit from the "dot_maxtries" experiment
// flag, passing kMaxTries to getFlag() as the default argument.
DnsTlsQueryMap::DnsTlsQueryMap() {
    const int configuredTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries);
    // A limit below one is not usable; fall back to a single try.
    mMaxTries = (configuredTries < 1) ? 1 : configuredTries;
}
// Registers |query| as pending and returns a QueryFuture through which the
// caller can obtain the eventual result.
//
// The query bytes are copied and stored under a newly allocated ID (see
// getFreeId()) so the query can be matched to the response or reissued.
//
// Returns nullptr if the query is shorter than two bytes, if every ID is
// already in use, or if the pending entry could not be inserted.
std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
        const netdutils::Slice query) {
    std::lock_guard guard(mLock);

    // onResponse() reads the first two bytes of the stored query, so anything
    // shorter cannot be tracked.
    if (query.size() < 2) {
        LOG(WARNING) << "Query is too short";
        return nullptr;
    }

    const int32_t newId = getFreeId();
    if (newId < 0) {
        LOG(WARNING) << "All query IDs are in use";
        return nullptr;
    }

    // Make a copy of the query.
    std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
    Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};

    // The map entry needs its own copy of |q|, so this must stay a copy.
    const auto [it, inserted] = mQueries.try_emplace(newId, q);
    if (!inserted) {
        LOG(ERROR) << "Failed to store pending query";
        return nullptr;
    }

    // |q| is not used after this point, so hand it over by move instead of
    // copying the packet bytes a second time.
    return std::make_unique<QueryFuture>(std::move(q), it->second.result.get_future());
}
// Completes the promise held by |p| with a Response::network_error result.
// This only fulfils the promise; removing the entry from the map is left to
// the caller.
void DnsTlsQueryMap::expire(QueryPromise* p) {
    Result failure = {.code = Response::network_error};
    p->result.set_value(failure);
}
// Increments the try counter of the pending query stored under |newId|.
// IDs that are not in the map are silently ignored.
void DnsTlsQueryMap::markTried(uint16_t newId) {
    std::lock_guard guard(mLock);
    if (auto entry = mQueries.find(newId); entry != mQueries.end()) {
        entry->second.tries++;
    }
}
// Expires and removes every pending query whose try count has reached
// mMaxTries. Queries below the limit are left untouched.
void DnsTlsQueryMap::cleanup() {
    std::lock_guard guard(mLock);
    auto cursor = mQueries.begin();
    while (cursor != mQueries.end()) {
        auto& pending = cursor->second;
        if (pending.tries < mMaxTries) {
            // Still has tries left; keep it.
            ++cursor;
            continue;
        }
        // Fulfil the promise before the entry is destroyed by erase().
        expire(&pending);
        // erase() returns the iterator following the removed element.
        cursor = mQueries.erase(cursor);
    }
}
// Picks an ID that is not currently a key in mQueries.
//
// Returns a value in [0, UINT16_MAX] on success, or -1 if no ID is available.
// This function does not take mLock itself; recordQuery() calls it with the
// lock already held.
int32_t DnsTlsQueryMap::getFreeId() {
    // Nothing pending: start from zero.
    if (mQueries.empty()) return 0;

    // Common case: the ID just above the largest one in use is free.
    const uint16_t highestId = mQueries.rbegin()->first;
    if (highestId != UINT16_MAX) return highestId + 1;

    // The top ID is taken. If every one of the 65536 IDs is taken, give up.
    if (mQueries.size() == UINT16_MAX + 1) return -1;

    // Otherwise there is a hole somewhere below the top. Walk the keys in
    // ascending order and return the first value that is skipped.
    uint16_t candidate = 0;
    for (const auto& entry : mQueries) {
        const uint16_t usedId = entry.first;
        if (usedId != candidate) return candidate;
        candidate = usedId + 1;
    }

    // Unreachable (but the compiler isn't smart enough to prove it).
    return -1;
}
std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
std::lock_guard guard(mLock);
std::vector<DnsTlsQueryMap::Query> queries;
for (auto& q : mQueries) {
queries.push_back(q.second.query);
}
return queries;
}
// Reports whether there are no pending queries. The check is made while
// holding mLock.
bool DnsTlsQueryMap::empty() {
    std::lock_guard guard(mLock);
    const bool nothingPending = mQueries.empty();
    return nothingPending;
}
// Expires every pending query (each promise receives a network_error result
// via expire()) and then empties the map.
void DnsTlsQueryMap::clear() {
    std::lock_guard guard(mLock);
    // Fulfil all promises first; the entries are destroyed by clear() below.
    for (auto it = mQueries.begin(); it != mQueries.end(); ++it) {
        expire(&it->second);
    }
    mQueries.clear();
}
void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
LOG(VERBOSE) << "Got response of size " << response.size();
if (response.size() < 2) {
LOG(WARNING) << "Response is too short";
return;
}
uint16_t id = response[0] << 8 | response[1];
std::lock_guard guard(mLock);
auto it = mQueries.find(id);
if (it == mQueries.end()) {
LOG(WARNING) << "Discarding response: unknown ID " << id;
return;
}
Result r = { .code = Response::success, .response = std::move(response) };
// Rewrite ID to match the query
const uint8_t* data = it->second.query.query.data();
r.response[0] = data[0];
r.response[1] = data[1];
LOG(DEBUG) << "Sending result to dispatcher";
it->second.result.set_value(std::move(r));
mQueries.erase(it);
}
} // end of namespace net
} // end of namespace android