1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsDispatcher"
18 //#define LOG_NDEBUG 0
19 
#include "dns/DnsTlsDispatcher.h"

#include <memory>

#include "log/log.h"
23 
24 namespace android {
25 namespace net {
26 
27 using netdutils::Slice;
28 
// static
// Mutex shared by every DnsTlsDispatcher instance (it is a static member, not a
// per-object one). In this file it is held while looking up, inserting into, and
// erasing from mStore, and while updating a cached Transport's useCount and
// lastUsed fields (including the cleanup() sweep).
std::mutex DnsTlsDispatcher::sLock;
31 
getOrderedServerList(const std::list<DnsTlsServer> & tlsServers,unsigned mark) const32 std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
33         const std::list<DnsTlsServer> &tlsServers, unsigned mark) const {
34     // Our preferred DnsTlsServer order is:
35     //     1) reuse existing IPv6 connections
36     //     2) reuse existing IPv4 connections
37     //     3) establish new IPv6 connections
38     //     4) establish new IPv4 connections
39     std::list<DnsTlsServer> existing6;
40     std::list<DnsTlsServer> existing4;
41     std::list<DnsTlsServer> new6;
42     std::list<DnsTlsServer> new4;
43 
44     // Pull out any servers for which we might have existing connections and
45     // place them at the from the list of servers to try.
46     {
47         std::lock_guard<std::mutex> guard(sLock);
48 
49         for (const auto& tlsServer : tlsServers) {
50             const Key key = std::make_pair(mark, tlsServer);
51             if (mStore.find(key) != mStore.end()) {
52                 switch (tlsServer.ss.ss_family) {
53                     case AF_INET:
54                         existing4.push_back(tlsServer);
55                         break;
56                     case AF_INET6:
57                         existing6.push_back(tlsServer);
58                         break;
59                 }
60             } else {
61                 switch (tlsServer.ss.ss_family) {
62                     case AF_INET:
63                         new4.push_back(tlsServer);
64                         break;
65                     case AF_INET6:
66                         new6.push_back(tlsServer);
67                         break;
68                 }
69             }
70         }
71     }
72 
73     auto& out = existing6;
74     out.splice(out.cend(), existing4);
75     out.splice(out.cend(), new6);
76     out.splice(out.cend(), new4);
77     return out;
78 }
79 
query(const std::list<DnsTlsServer> & tlsServers,unsigned mark,const Slice query,const Slice ans,int * resplen)80 DnsTlsTransport::Response DnsTlsDispatcher::query(
81         const std::list<DnsTlsServer> &tlsServers, unsigned mark,
82         const Slice query, const Slice ans, int *resplen) {
83     const std::list<DnsTlsServer> orderedServers(getOrderedServerList(tlsServers, mark));
84 
85     if (orderedServers.empty()) ALOGW("Empty DnsTlsServer list");
86 
87     DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error;
88     for (const auto& server : orderedServers) {
89         code = this->query(server, mark, query, ans, resplen);
90         switch (code) {
91             // These response codes are valid responses and not expected to
92             // change if another server is queried.
93             case DnsTlsTransport::Response::success:
94             case DnsTlsTransport::Response::limit_error:
95                 return code;
96                 break;
97             // These response codes might differ when trying other servers, so
98             // keep iterating to see if we can get a different (better) result.
99             case DnsTlsTransport::Response::network_error:
100             case DnsTlsTransport::Response::internal_error:
101                 continue;
102                 break;
103             // No "default" statement.
104         }
105     }
106 
107     return code;
108 }
109 
query(const DnsTlsServer & server,unsigned mark,const Slice query,const Slice ans,int * resplen)110 DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
111                                                   const Slice query,
112                                                   const Slice ans, int *resplen) {
113     const Key key = std::make_pair(mark, server);
114     Transport* xport;
115     {
116         std::lock_guard<std::mutex> guard(sLock);
117         auto it = mStore.find(key);
118         if (it == mStore.end()) {
119             xport = new Transport(server, mark, mFactory.get());
120             mStore[key].reset(xport);
121         } else {
122             xport = it->second.get();
123         }
124         ++xport->useCount;
125     }
126 
127     ALOGV("Sending query of length %zu", query.size());
128     auto res = xport->transport.query(query);
129     ALOGV("Awaiting response");
130     const auto& result = res.get();
131     DnsTlsTransport::Response code = result.code;
132     if (code == DnsTlsTransport::Response::success) {
133         if (result.response.size() > ans.size()) {
134             ALOGV("Response too large: %zu > %zu", result.response.size(), ans.size());
135             code = DnsTlsTransport::Response::limit_error;
136         } else {
137             ALOGV("Got response successfully");
138             *resplen = result.response.size();
139             netdutils::copy(ans, netdutils::makeSlice(result.response));
140         }
141     } else {
142         ALOGV("Query failed: %u", (unsigned int) code);
143     }
144 
145     auto now = std::chrono::steady_clock::now();
146     {
147         std::lock_guard<std::mutex> guard(sLock);
148         --xport->useCount;
149         xport->lastUsed = now;
150         cleanup(now);
151     }
152     return code;
153 }
154 
155 // This timeout effectively controls how long to keep SSL session tickets.
156 static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
cleanup(std::chrono::time_point<std::chrono::steady_clock> now)157 void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
158     // To avoid scanning mStore after every query, return early if a cleanup has been
159     // performed recently.
160     if (now - mLastCleanup < IDLE_TIMEOUT) {
161         return;
162     }
163     for (auto it = mStore.begin(); it != mStore.end();) {
164         auto& s = it->second;
165         if (s->useCount == 0 && now - s->lastUsed > IDLE_TIMEOUT) {
166             it = mStore.erase(it);
167         } else {
168             ++it;
169         }
170     }
171     mLastCleanup = now;
172 }
173 
174 }  // end of namespace net
175 }  // end of namespace android
176