1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include "DnsTlsQueryMap.h"
20
21 #include <android-base/logging.h>
22
23 #include "Experiments.h"
24
25 namespace android {
26 namespace net {
27
DnsTlsQueryMap()28 DnsTlsQueryMap::DnsTlsQueryMap() {
29 mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries);
30 if (mMaxTries < 1) mMaxTries = 1;
31 }
32
recordQuery(const netdutils::Slice query)33 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
34 const netdutils::Slice query) {
35 std::lock_guard guard(mLock);
36
37 // Store the query so it can be matched to the response or reissued.
38 if (query.size() < 2) {
39 LOG(WARNING) << "Query is too short";
40 return nullptr;
41 }
42 int32_t newId = getFreeId();
43 if (newId < 0) {
44 LOG(WARNING) << "All query IDs are in use";
45 return nullptr;
46 }
47
48 // Make a copy of the query.
49 std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
50 Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};
51
52 const auto [it, inserted] = mQueries.try_emplace(newId, q);
53 if (!inserted) {
54 LOG(ERROR) << "Failed to store pending query";
55 return nullptr;
56 }
57 return std::make_unique<QueryFuture>(q, it->second.result.get_future());
58 }
59
expire(QueryPromise * p)60 void DnsTlsQueryMap::expire(QueryPromise* p) {
61 Result r = { .code = Response::network_error };
62 p->result.set_value(r);
63 }
64
markTried(uint16_t newId)65 void DnsTlsQueryMap::markTried(uint16_t newId) {
66 std::lock_guard guard(mLock);
67 auto it = mQueries.find(newId);
68 if (it != mQueries.end()) {
69 it->second.tries++;
70 }
71 }
72
cleanup()73 void DnsTlsQueryMap::cleanup() {
74 std::lock_guard guard(mLock);
75 for (auto it = mQueries.begin(); it != mQueries.end();) {
76 auto& p = it->second;
77 if (p.tries >= mMaxTries) {
78 expire(&p);
79 it = mQueries.erase(it);
80 } else {
81 ++it;
82 }
83 }
84 }
85
getFreeId()86 int32_t DnsTlsQueryMap::getFreeId() {
87 if (mQueries.empty()) {
88 return 0;
89 }
90 uint16_t maxId = mQueries.rbegin()->first;
91 if (maxId < UINT16_MAX) {
92 return maxId + 1;
93 }
94 if (mQueries.size() == UINT16_MAX + 1) {
95 // Map is full.
96 return -1;
97 }
98 // Linear scan.
99 uint16_t nextId = 0;
100 for (auto& pair : mQueries) {
101 uint16_t id = pair.first;
102 if (id != nextId) {
103 // Found a gap.
104 return nextId;
105 }
106 nextId = id + 1;
107 }
108 // Unreachable (but the compiler isn't smart enough to prove it).
109 return -1;
110 }
111
getAll()112 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
113 std::lock_guard guard(mLock);
114 std::vector<DnsTlsQueryMap::Query> queries;
115 queries.reserve(mQueries.size());
116 for (auto& q : mQueries) {
117 queries.push_back(q.second.query);
118 }
119 return queries;
120 }
121
empty()122 bool DnsTlsQueryMap::empty() {
123 std::lock_guard guard(mLock);
124 return mQueries.empty();
125 }
126
clear()127 void DnsTlsQueryMap::clear() {
128 std::lock_guard guard(mLock);
129 for (auto& q : mQueries) {
130 expire(&q.second);
131 }
132 mQueries.clear();
133 }
134
onResponse(std::vector<uint8_t> response)135 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
136 LOG(VERBOSE) << "Got response of size " << response.size();
137 if (response.size() < 2) {
138 LOG(WARNING) << "Response is too short";
139 return;
140 }
141 uint16_t id = response[0] << 8 | response[1];
142 std::lock_guard guard(mLock);
143 auto it = mQueries.find(id);
144 if (it == mQueries.end()) {
145 LOG(WARNING) << "Discarding response: unknown ID " << id;
146 return;
147 }
148 Result r = { .code = Response::success, .response = std::move(response) };
149 // Rewrite ID to match the query
150 const uint8_t* data = it->second.query.query.data();
151 r.response[0] = data[0];
152 r.response[1] = data[1];
153 LOG(DEBUG) << "Sending result to dispatcher";
154 it->second.result.set_value(std::move(r));
155 mQueries.erase(it);
156 }
157
158 } // end of namespace net
159 } // end of namespace android
160