1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "resolv" 18 19 #include "DnsTlsQueryMap.h" 20 21 #include <android-base/logging.h> 22 23 #include "Experiments.h" 24 25 namespace android { 26 namespace net { 27 28 DnsTlsQueryMap::DnsTlsQueryMap() { 29 mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries); 30 if (mMaxTries < 1) mMaxTries = 1; 31 } 32 33 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery( 34 const netdutils::Slice query) { 35 std::lock_guard guard(mLock); 36 37 // Store the query so it can be matched to the response or reissued. 38 if (query.size() < 2) { 39 LOG(WARNING) << "Query is too short"; 40 return nullptr; 41 } 42 int32_t newId = getFreeId(); 43 if (newId < 0) { 44 LOG(WARNING) << "All query IDs are in use"; 45 return nullptr; 46 } 47 48 // Make a copy of the query. 49 std::vector<uint8_t> tmp(query.base(), query.base() + query.size()); 50 Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)}; 51 52 const auto [it, inserted] = mQueries.try_emplace(newId, q); 53 if (!inserted) { 54 LOG(ERROR) << "Failed to store pending query"; 55 return nullptr; 56 } 57 return std::make_unique<QueryFuture>(q, it->second.result.get_future()); 58 } 59 60 void DnsTlsQueryMap::expire(QueryPromise* p) { 61 Result r = { .code = Response::network_error }; 62 p->result.set_value(r); 63 } 64 65 void DnsTlsQueryMap::markTried(uint16_t newId) { 66 std::lock_guard guard(mLock); 67 auto it = mQueries.find(newId); 68 if (it != mQueries.end()) { 69 it->second.tries++; 70 } 71 } 72 73 void DnsTlsQueryMap::cleanup() { 74 std::lock_guard guard(mLock); 75 for (auto it = mQueries.begin(); it != mQueries.end();) { 76 auto& p = it->second; 77 if (p.tries >= mMaxTries) { 78 expire(&p); 79 it = mQueries.erase(it); 80 } else { 81 ++it; 82 } 83 } 84 } 85 86 int32_t DnsTlsQueryMap::getFreeId() { 87 if (mQueries.empty()) { 88 return 0; 89 } 90 uint16_t maxId = mQueries.rbegin()->first; 91 if (maxId < UINT16_MAX) { 92 return maxId + 1; 93 } 94 if (mQueries.size() == UINT16_MAX + 1) { 95 // Map is full. 96 return -1; 97 } 98 // Linear scan. 99 uint16_t nextId = 0; 100 for (auto& pair : mQueries) { 101 uint16_t id = pair.first; 102 if (id != nextId) { 103 // Found a gap. 104 return nextId; 105 } 106 nextId = id + 1; 107 } 108 // Unreachable (but the compiler isn't smart enough to prove it). 109 return -1; 110 } 111 112 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() { 113 std::lock_guard guard(mLock); 114 std::vector<DnsTlsQueryMap::Query> queries; 115 for (auto& q : mQueries) { 116 queries.push_back(q.second.query); 117 } 118 return queries; 119 } 120 121 bool DnsTlsQueryMap::empty() { 122 std::lock_guard guard(mLock); 123 return mQueries.empty(); 124 } 125 126 void DnsTlsQueryMap::clear() { 127 std::lock_guard guard(mLock); 128 for (auto& q : mQueries) { 129 expire(&q.second); 130 } 131 mQueries.clear(); 132 } 133 134 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) { 135 LOG(VERBOSE) << "Got response of size " << response.size(); 136 if (response.size() < 2) { 137 LOG(WARNING) << "Response is too short"; 138 return; 139 } 140 uint16_t id = response[0] << 8 | response[1]; 141 std::lock_guard guard(mLock); 142 auto it = mQueries.find(id); 143 if (it == mQueries.end()) { 144 LOG(WARNING) << "Discarding response: unknown ID " << id; 145 return; 146 } 147 Result r = { .code = Response::success, .response = std::move(response) }; 148 // Rewrite ID to match the query 149 const uint8_t* data = it->second.query.query.data(); 150 r.response[0] = data[0]; 151 r.response[1] = data[1]; 152 LOG(DEBUG) << "Sending result to dispatcher"; 153 it->second.result.set_value(std::move(r)); 154 mQueries.erase(it); 155 } 156 157 } // end of namespace net 158 } // end of namespace android 159