1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include "DnsTlsQueryMap.h"
20 
21 #include <android-base/logging.h>
22 
23 #include "Experiments.h"
24 
25 namespace android {
26 namespace net {
27 
28 DnsTlsQueryMap::DnsTlsQueryMap() {
29     mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries);
30     if (mMaxTries < 1) mMaxTries = 1;
31 }
32 
33 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
34         const netdutils::Slice query) {
35     std::lock_guard guard(mLock);
36 
37     // Store the query so it can be matched to the response or reissued.
38     if (query.size() < 2) {
39         LOG(WARNING) << "Query is too short";
40         return nullptr;
41     }
42     int32_t newId = getFreeId();
43     if (newId < 0) {
44         LOG(WARNING) << "All query IDs are in use";
45         return nullptr;
46     }
47 
48     // Make a copy of the query.
49     std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
50     Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};
51 
52     const auto [it, inserted] = mQueries.try_emplace(newId, q);
53     if (!inserted) {
54         LOG(ERROR) << "Failed to store pending query";
55         return nullptr;
56     }
57     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
58 }
59 
60 void DnsTlsQueryMap::expire(QueryPromise* p) {
61     Result r = { .code = Response::network_error };
62     p->result.set_value(r);
63 }
64 
65 void DnsTlsQueryMap::markTried(uint16_t newId) {
66     std::lock_guard guard(mLock);
67     auto it = mQueries.find(newId);
68     if (it != mQueries.end()) {
69         it->second.tries++;
70     }
71 }
72 
73 void DnsTlsQueryMap::cleanup() {
74     std::lock_guard guard(mLock);
75     for (auto it = mQueries.begin(); it != mQueries.end();) {
76         auto& p = it->second;
77         if (p.tries >= mMaxTries) {
78             expire(&p);
79             it = mQueries.erase(it);
80         } else {
81             ++it;
82         }
83     }
84 }
85 
86 int32_t DnsTlsQueryMap::getFreeId() {
87     if (mQueries.empty()) {
88         return 0;
89     }
90     uint16_t maxId = mQueries.rbegin()->first;
91     if (maxId < UINT16_MAX) {
92         return maxId + 1;
93     }
94     if (mQueries.size() == UINT16_MAX + 1) {
95         // Map is full.
96         return -1;
97     }
98     // Linear scan.
99     uint16_t nextId = 0;
100     for (auto& pair : mQueries) {
101         uint16_t id = pair.first;
102         if (id != nextId) {
103             // Found a gap.
104             return nextId;
105         }
106         nextId = id + 1;
107     }
108     // Unreachable (but the compiler isn't smart enough to prove it).
109     return -1;
110 }
111 
112 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
113     std::lock_guard guard(mLock);
114     std::vector<DnsTlsQueryMap::Query> queries;
115     for (auto& q : mQueries) {
116         queries.push_back(q.second.query);
117     }
118     return queries;
119 }
120 
121 bool DnsTlsQueryMap::empty() {
122     std::lock_guard guard(mLock);
123     return mQueries.empty();
124 }
125 
126 void DnsTlsQueryMap::clear() {
127     std::lock_guard guard(mLock);
128     for (auto& q : mQueries) {
129         expire(&q.second);
130     }
131     mQueries.clear();
132 }
133 
134 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
135     LOG(VERBOSE) << "Got response of size " << response.size();
136     if (response.size() < 2) {
137         LOG(WARNING) << "Response is too short";
138         return;
139     }
140     uint16_t id = response[0] << 8 | response[1];
141     std::lock_guard guard(mLock);
142     auto it = mQueries.find(id);
143     if (it == mQueries.end()) {
144         LOG(WARNING) << "Discarding response: unknown ID " << id;
145         return;
146     }
147     Result r = { .code = Response::success, .response = std::move(response) };
148     // Rewrite ID to match the query
149     const uint8_t* data = it->second.query.query.data();
150     r.response[0] = data[0];
151     r.response[1] = data[1];
152     LOG(DEBUG) << "Sending result to dispatcher";
153     it->second.result.set_value(std::move(r));
154     mQueries.erase(it);
155 }
156 
157 }  // end of namespace net
158 }  // end of namespace android
159