1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include "DnsTlsQueryMap.h"
20 
21 #include <android-base/logging.h>
22 
23 namespace android {
24 namespace net {
25 
recordQuery(const netdutils::Slice query)26 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
27         const netdutils::Slice query) {
28     std::lock_guard guard(mLock);
29 
30     // Store the query so it can be matched to the response or reissued.
31     if (query.size() < 2) {
32         LOG(WARNING) << "Query is too short";
33         return nullptr;
34     }
35     int32_t newId = getFreeId();
36     if (newId < 0) {
37         LOG(WARNING) << "All query IDs are in use";
38         return nullptr;
39     }
40 
41     // Make a copy of the query.
42     std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
43     Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};
44 
45     const auto [it, inserted] = mQueries.try_emplace(newId, q);
46     if (!inserted) {
47         LOG(ERROR) << "Failed to store pending query";
48         return nullptr;
49     }
50     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
51 }
52 
expire(QueryPromise * p)53 void DnsTlsQueryMap::expire(QueryPromise* p) {
54     Result r = { .code = Response::network_error };
55     p->result.set_value(r);
56 }
57 
markTried(uint16_t newId)58 void DnsTlsQueryMap::markTried(uint16_t newId) {
59     std::lock_guard guard(mLock);
60     auto it = mQueries.find(newId);
61     if (it != mQueries.end()) {
62         it->second.tries++;
63     }
64 }
65 
cleanup()66 void DnsTlsQueryMap::cleanup() {
67     std::lock_guard guard(mLock);
68     for (auto it = mQueries.begin(); it != mQueries.end();) {
69         auto& p = it->second;
70         if (p.tries >= kMaxTries) {
71             expire(&p);
72             it = mQueries.erase(it);
73         } else {
74             ++it;
75         }
76     }
77 }
78 
getFreeId()79 int32_t DnsTlsQueryMap::getFreeId() {
80     if (mQueries.empty()) {
81         return 0;
82     }
83     uint16_t maxId = mQueries.rbegin()->first;
84     if (maxId < UINT16_MAX) {
85         return maxId + 1;
86     }
87     if (mQueries.size() == UINT16_MAX + 1) {
88         // Map is full.
89         return -1;
90     }
91     // Linear scan.
92     uint16_t nextId = 0;
93     for (auto& pair : mQueries) {
94         uint16_t id = pair.first;
95         if (id != nextId) {
96             // Found a gap.
97             return nextId;
98         }
99         nextId = id + 1;
100     }
101     // Unreachable (but the compiler isn't smart enough to prove it).
102     return -1;
103 }
104 
getAll()105 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
106     std::lock_guard guard(mLock);
107     std::vector<DnsTlsQueryMap::Query> queries;
108     for (auto& q : mQueries) {
109         queries.push_back(q.second.query);
110     }
111     return queries;
112 }
113 
empty()114 bool DnsTlsQueryMap::empty() {
115     std::lock_guard guard(mLock);
116     return mQueries.empty();
117 }
118 
clear()119 void DnsTlsQueryMap::clear() {
120     std::lock_guard guard(mLock);
121     for (auto& q : mQueries) {
122         expire(&q.second);
123     }
124     mQueries.clear();
125 }
126 
onResponse(std::vector<uint8_t> response)127 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
128     LOG(VERBOSE) << "Got response of size " << response.size();
129     if (response.size() < 2) {
130         LOG(WARNING) << "Response is too short";
131         return;
132     }
133     uint16_t id = response[0] << 8 | response[1];
134     std::lock_guard guard(mLock);
135     auto it = mQueries.find(id);
136     if (it == mQueries.end()) {
137         LOG(WARNING) << "Discarding response: unknown ID " << id;
138         return;
139     }
140     Result r = { .code = Response::success, .response = std::move(response) };
141     // Rewrite ID to match the query
142     const uint8_t* data = it->second.query.query.data();
143     r.response[0] = data[0];
144     r.response[1] = data[1];
145     LOG(DEBUG) << "Sending result to dispatcher";
146     it->second.result.set_value(std::move(r));
147     mQueries.erase(it);
148 }
149 
150 }  // end of namespace net
151 }  // end of namespace android
152