/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "resolv" #include "DnsTlsQueryMap.h" #include namespace android { namespace net { std::unique_ptr DnsTlsQueryMap::recordQuery( const netdutils::Slice query) { std::lock_guard guard(mLock); // Store the query so it can be matched to the response or reissued. if (query.size() < 2) { LOG(WARNING) << "Query is too short"; return nullptr; } int32_t newId = getFreeId(); if (newId < 0) { LOG(WARNING) << "All query IDs are in use"; return nullptr; } // Make a copy of the query. std::vector tmp(query.base(), query.base() + query.size()); Query q = {.newId = static_cast(newId), .query = std::move(tmp)}; const auto [it, inserted] = mQueries.try_emplace(newId, q); if (!inserted) { LOG(ERROR) << "Failed to store pending query"; return nullptr; } return std::make_unique(q, it->second.result.get_future()); } void DnsTlsQueryMap::expire(QueryPromise* p) { Result r = { .code = Response::network_error }; p->result.set_value(r); } void DnsTlsQueryMap::markTried(uint16_t newId) { std::lock_guard guard(mLock); auto it = mQueries.find(newId); if (it != mQueries.end()) { it->second.tries++; } } void DnsTlsQueryMap::cleanup() { std::lock_guard guard(mLock); for (auto it = mQueries.begin(); it != mQueries.end();) { auto& p = it->second; if (p.tries >= kMaxTries) { expire(&p); it = mQueries.erase(it); } else { ++it; } } } int32_t DnsTlsQueryMap::getFreeId() { if (mQueries.empty()) { return 0; } uint16_t maxId = mQueries.rbegin()->first; if (maxId < UINT16_MAX) { return maxId + 1; } if (mQueries.size() == UINT16_MAX + 1) { // Map is full. return -1; } // Linear scan. uint16_t nextId = 0; for (auto& pair : mQueries) { uint16_t id = pair.first; if (id != nextId) { // Found a gap. return nextId; } nextId = id + 1; } // Unreachable (but the compiler isn't smart enough to prove it). return -1; } std::vector DnsTlsQueryMap::getAll() { std::lock_guard guard(mLock); std::vector queries; for (auto& q : mQueries) { queries.push_back(q.second.query); } return queries; } bool DnsTlsQueryMap::empty() { std::lock_guard guard(mLock); return mQueries.empty(); } void DnsTlsQueryMap::clear() { std::lock_guard guard(mLock); for (auto& q : mQueries) { expire(&q.second); } mQueries.clear(); } void DnsTlsQueryMap::onResponse(std::vector response) { LOG(VERBOSE) << "Got response of size " << response.size(); if (response.size() < 2) { LOG(WARNING) << "Response is too short"; return; } uint16_t id = response[0] << 8 | response[1]; std::lock_guard guard(mLock); auto it = mQueries.find(id); if (it == mQueries.end()) { LOG(WARNING) << "Discarding response: unknown ID " << id; return; } Result r = { .code = Response::success, .response = std::move(response) }; // Rewrite ID to match the query const uint8_t* data = it->second.query.query.data(); r.response[0] = data[0]; r.response[1] = data[1]; LOG(DEBUG) << "Sending result to dispatcher"; it->second.result.set_value(std::move(r)); mQueries.erase(it); } } // end of namespace net } // end of namespace android