// 159 lines, 4.4 KiB, C++
|
/*
|
||
|
* Copyright (C) 2018 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#define LOG_TAG "resolv"
|
||
|
|
||
|
#include "DnsTlsQueryMap.h"
|
||
|
|
||
|
#include <android-base/logging.h>
|
||
|
|
||
|
#include "Experiments.h"
|
||
|
|
||
|
namespace android {
|
||
|
namespace net {
|
||
|
|
||
|
// Reads the per-query attempt limit from the "dot_maxtries" experiment flag, falling back to
// kMaxTries when the flag does not supply a value.
DnsTlsQueryMap::DnsTlsQueryMap()
    : mMaxTries(Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries)) {
    // Never allow a limit below one: cleanup() expires a query once tries >= mMaxTries, so a
    // limit of at least one means a query must be marked tried before it can be expired.
    if (mMaxTries < 1) {
        mMaxTries = 1;
    }
}
|
||
|
|
||
|
std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
|
||
|
const netdutils::Slice query) {
|
||
|
std::lock_guard guard(mLock);
|
||
|
|
||
|
// Store the query so it can be matched to the response or reissued.
|
||
|
if (query.size() < 2) {
|
||
|
LOG(WARNING) << "Query is too short";
|
||
|
return nullptr;
|
||
|
}
|
||
|
int32_t newId = getFreeId();
|
||
|
if (newId < 0) {
|
||
|
LOG(WARNING) << "All query IDs are in use";
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
// Make a copy of the query.
|
||
|
std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
|
||
|
Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};
|
||
|
|
||
|
const auto [it, inserted] = mQueries.try_emplace(newId, q);
|
||
|
if (!inserted) {
|
||
|
LOG(ERROR) << "Failed to store pending query";
|
||
|
return nullptr;
|
||
|
}
|
||
|
return std::make_unique<QueryFuture>(q, it->second.result.get_future());
|
||
|
}
|
||
|
|
||
|
// Fulfills the pending query's promise with a Response::network_error result. Only the code is
// set; the remaining Result fields are left at their default values.
void DnsTlsQueryMap::expire(QueryPromise* p) {
    p->result.set_value(Result{.code = Response::network_error});
}
|
||
|
|
||
|
void DnsTlsQueryMap::markTried(uint16_t newId) {
|
||
|
std::lock_guard guard(mLock);
|
||
|
auto it = mQueries.find(newId);
|
||
|
if (it != mQueries.end()) {
|
||
|
it->second.tries++;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void DnsTlsQueryMap::cleanup() {
|
||
|
std::lock_guard guard(mLock);
|
||
|
for (auto it = mQueries.begin(); it != mQueries.end();) {
|
||
|
auto& p = it->second;
|
||
|
if (p.tries >= mMaxTries) {
|
||
|
expire(&p);
|
||
|
it = mQueries.erase(it);
|
||
|
} else {
|
||
|
++it;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
int32_t DnsTlsQueryMap::getFreeId() {
|
||
|
if (mQueries.empty()) {
|
||
|
return 0;
|
||
|
}
|
||
|
uint16_t maxId = mQueries.rbegin()->first;
|
||
|
if (maxId < UINT16_MAX) {
|
||
|
return maxId + 1;
|
||
|
}
|
||
|
if (mQueries.size() == UINT16_MAX + 1) {
|
||
|
// Map is full.
|
||
|
return -1;
|
||
|
}
|
||
|
// Linear scan.
|
||
|
uint16_t nextId = 0;
|
||
|
for (auto& pair : mQueries) {
|
||
|
uint16_t id = pair.first;
|
||
|
if (id != nextId) {
|
||
|
// Found a gap.
|
||
|
return nextId;
|
||
|
}
|
||
|
nextId = id + 1;
|
||
|
}
|
||
|
// Unreachable (but the compiler isn't smart enough to prove it).
|
||
|
return -1;
|
||
|
}
|
||
|
|
||
|
std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
|
||
|
std::lock_guard guard(mLock);
|
||
|
std::vector<DnsTlsQueryMap::Query> queries;
|
||
|
for (auto& q : mQueries) {
|
||
|
queries.push_back(q.second.query);
|
||
|
}
|
||
|
return queries;
|
||
|
}
|
||
|
|
||
|
bool DnsTlsQueryMap::empty() {
|
||
|
std::lock_guard guard(mLock);
|
||
|
return mQueries.empty();
|
||
|
}
|
||
|
|
||
|
void DnsTlsQueryMap::clear() {
|
||
|
std::lock_guard guard(mLock);
|
||
|
for (auto& q : mQueries) {
|
||
|
expire(&q.second);
|
||
|
}
|
||
|
mQueries.clear();
|
||
|
}
|
||
|
|
||
|
void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
|
||
|
LOG(VERBOSE) << "Got response of size " << response.size();
|
||
|
if (response.size() < 2) {
|
||
|
LOG(WARNING) << "Response is too short";
|
||
|
return;
|
||
|
}
|
||
|
uint16_t id = response[0] << 8 | response[1];
|
||
|
std::lock_guard guard(mLock);
|
||
|
auto it = mQueries.find(id);
|
||
|
if (it == mQueries.end()) {
|
||
|
LOG(WARNING) << "Discarding response: unknown ID " << id;
|
||
|
return;
|
||
|
}
|
||
|
Result r = { .code = Response::success, .response = std::move(response) };
|
||
|
// Rewrite ID to match the query
|
||
|
const uint8_t* data = it->second.query.query.data();
|
||
|
r.response[0] = data[0];
|
||
|
r.response[1] = data[1];
|
||
|
LOG(DEBUG) << "Sending result to dispatcher";
|
||
|
it->second.result.set_value(std::move(r));
|
||
|
mQueries.erase(it);
|
||
|
}
|
||
|
|
||
|
} // end of namespace net
|
||
|
} // end of namespace android
|