1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include <arpa/inet.h>
20 
21 #include <chrono>
22 
23 #include <android-base/logging.h>
24 #include <android-base/macros.h>
25 #include <gtest/gtest.h>
26 #include <netdutils/Slice.h>
27 
28 #include "DnsTlsDispatcher.h"
29 #include "DnsTlsQueryMap.h"
30 #include "DnsTlsServer.h"
31 #include "DnsTlsSessionCache.h"
32 #include "DnsTlsSocket.h"
33 #include "DnsTlsTransport.h"
34 #include "IDnsTlsSocket.h"
35 #include "IDnsTlsSocketFactory.h"
36 #include "IDnsTlsSocketObserver.h"
37 #include "tests/dns_responder/dns_tls_frontend.h"
38 
39 namespace android {
40 namespace net {
41 
42 using netdutils::Slice;
43 using netdutils::makeSlice;
44 
45 typedef std::vector<uint8_t> bytevec;
46 
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)47 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
48     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
49     if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
50         // IPv4 parse succeeded, so it's IPv4
51         sin->sin_family = AF_INET;
52         sin->sin_port = htons(port);
53         return;
54     }
55     sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
56     if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
57         // IPv6 parse succeeded, so it's IPv6.
58         sin6->sin6_family = AF_INET6;
59         sin6->sin6_port = htons(port);
60         return;
61     }
62     LOG(ERROR) << "Failed to parse server address: " << server;
63 }
64 
// Server hostnames shared by the tests.
std::string SERVERNAME1{"dns.example.com"};
std::string SERVERNAME2{"dns.example.org"};
67 
68 // BaseTest just provides constants that are useful for the tests.
69 class BaseTest : public ::testing::Test {
70   protected:
BaseTest()71     BaseTest() {
72         parseServer("192.0.2.1", 853, &V4ADDR1);
73         parseServer("192.0.2.2", 853, &V4ADDR2);
74         parseServer("2001:db8::1", 853, &V6ADDR1);
75         parseServer("2001:db8::2", 853, &V6ADDR2);
76 
77         SERVER1 = DnsTlsServer(V4ADDR1);
78         SERVER1.name = SERVERNAME1;
79     }
80 
81     sockaddr_storage V4ADDR1;
82     sockaddr_storage V4ADDR2;
83     sockaddr_storage V6ADDR1;
84     sockaddr_storage V6ADDR2;
85 
86     DnsTlsServer SERVER1;
87 };
88 
// Builds a fake DNS query of `size` bytes.  The first two bytes hold `id`,
// high byte first; every remaining byte i holds the low byte of (id + i), so
// queries with different ids get distinct bodies.  Sizes smaller than the
// two-byte header are truncated instead of being written out of bounds.
std::vector<uint8_t> make_query(uint16_t id, size_t size) {
    std::vector<uint8_t> vec(size);
    if (size > 0) {
        vec[0] = static_cast<uint8_t>(id >> 8);
    }
    if (size > 1) {
        vec[1] = static_cast<uint8_t>(id);
    }
    // Arbitrarily fill the query body with unique data.
    for (size_t i = 2; i < size; ++i) {
        vec[i] = static_cast<uint8_t>(id + i);
    }
    return vec;
}
99 
100 // Query constants
101 const unsigned MARK = 123;
102 const uint16_t ID = 52;
103 const uint16_t SIZE = 22;
104 const bytevec QUERY = make_query(ID, SIZE);
105 
106 template <class T>
107 class FakeSocketFactory : public IDnsTlsSocketFactory {
108   public:
FakeSocketFactory()109     FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)110     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
111             const DnsTlsServer& server ATTRIBUTE_UNUSED,
112             unsigned mark ATTRIBUTE_UNUSED,
113             IDnsTlsSocketObserver* observer,
114             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
115         return std::make_unique<T>(observer);
116     }
117 };
118 
make_echo(uint16_t id,const Slice query)119 bytevec make_echo(uint16_t id, const Slice query) {
120     bytevec response(query.size() + 2);
121     response[0] = id >> 8;
122     response[1] = id;
123     // Echo the query as the fake response.
124     memcpy(response.data() + 2, query.base(), query.size());
125     return response;
126 }
127 
128 // Simplest possible fake server.  This just echoes the query as the response.
129 class FakeSocketEcho : public IDnsTlsSocket {
130   public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)131     explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)132     bool query(uint16_t id, const Slice query) override {
133         // Return the response immediately (asynchronously).
134         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
135         return true;
136     }
137 
138   private:
139     IDnsTlsSocketObserver* const mObserver;
140 };
141 
142 class TransportTest : public BaseTest {};
143 
// A single query through an echo server succeeds, returns the query bytes
// unchanged and uses one connection.
TEST_F(TransportTest, Query) {
    FakeSocketFactory<FakeSocketEcho> echoFactory;
    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);
    const DnsTlsTransport::Result result = transport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    EXPECT_EQ(QUERY, result.response);
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
153 
154 // Fake Socket that echoes the observed query ID as the response body.
155 class FakeSocketId : public IDnsTlsSocket {
156   public:
FakeSocketId(IDnsTlsSocketObserver * observer)157     explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query ATTRIBUTE_UNUSED)158     bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
159         // Return the response immediately (asynchronously).
160         bytevec response(4);
161         // Echo the ID in the header to match the response to the query.
162         // This will be overwritten by DnsTlsQueryMap.
163         response[0] = id >> 8;
164         response[1] = id;
165         // Echo the ID in the body, so that the test can verify which ID was used by
166         // DnsTlsQueryMap.
167         response[2] = id >> 8;
168         response[3] = id;
169         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
170         return true;
171     }
172 
173   private:
174     IDnsTlsSocketObserver* const mObserver;
175 };
176 
177 // Test that IDs are properly reused
TEST_F(TransportTest,IdReuse)178 TEST_F(TransportTest, IdReuse) {
179     FakeSocketFactory<FakeSocketId> factory;
180     DnsTlsTransport transport(SERVER1, MARK, &factory);
181     for (int i = 0; i < 100; ++i) {
182         // Send a query.
183         std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
184         // Wait for the response.
185         DnsTlsTransport::Result r = f.get();
186         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
187 
188         // All queries should have an observed ID of zero, because it is returned to the ID pool
189         // after each use.
190         EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
191     }
192     EXPECT_EQ(transport.getConnectCounter(), 1);
193 }
194 
195 // These queries might be handled in serial or parallel as they race the
196 // responses.
TEST_F(TransportTest,RacingQueries_10000)197 TEST_F(TransportTest, RacingQueries_10000) {
198     FakeSocketFactory<FakeSocketEcho> factory;
199     DnsTlsTransport transport(SERVER1, MARK, &factory);
200     std::vector<std::future<DnsTlsTransport::Result>> results;
201     // Fewer than 65536 queries to avoid ID exhaustion.
202     const int num_queries = 10000;
203     results.reserve(num_queries);
204     for (int i = 0; i < num_queries; ++i) {
205         results.push_back(transport.query(makeSlice(QUERY)));
206     }
207     for (auto& result : results) {
208         auto r = result.get();
209         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
210         EXPECT_EQ(QUERY, r.response);
211     }
212     EXPECT_EQ(transport.getConnectCounter(), 1);
213 }
214 
215 // A server that waits until sDelay queries are queued before responding.
216 class FakeSocketDelay : public IDnsTlsSocket {
217   public:
FakeSocketDelay(IDnsTlsSocketObserver * observer)218     explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
~FakeSocketDelay()219     ~FakeSocketDelay() { std::lock_guard guard(mLock); }
220     static size_t sDelay;
221     static bool sReverse;
222 
query(uint16_t id,const Slice query)223     bool query(uint16_t id, const Slice query) override {
224         LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
225         std::lock_guard guard(mLock);
226         // Check for duplicate IDs.
227         EXPECT_EQ(0U, mIds.count(id));
228         mIds.insert(id);
229 
230         // Store response.
231         mResponses.push_back(make_echo(id, query));
232 
233         LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
234         if (mResponses.size() == sDelay) {
235             std::thread(&FakeSocketDelay::sendResponses, this).detach();
236         }
237         return true;
238     }
239 
240   private:
sendResponses()241     void sendResponses() {
242         std::lock_guard guard(mLock);
243         if (sReverse) {
244             std::reverse(std::begin(mResponses), std::end(mResponses));
245         }
246         for (auto& response : mResponses) {
247             mObserver->onResponse(response);
248         }
249         mIds.clear();
250         mResponses.clear();
251     }
252 
253     std::mutex mLock;
254     IDnsTlsSocketObserver* const mObserver;
255     std::set<uint16_t> mIds GUARDED_BY(mLock);
256     std::vector<bytevec> mResponses GUARDED_BY(mLock);
257 };
258 
259 size_t FakeSocketDelay::sDelay;
260 bool FakeSocketDelay::sReverse;
261 
// sDelay identical queries (same original ID) are issued together.  The fake
// server holds every response until all of them have arrived, so all of them
// are in flight at once.
TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Fewer than 65536 queries to avoid ID exhaustion.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    while (pending.size() < FakeSocketDelay::sDelay) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const DnsTlsTransport::Result result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
280 
// Like ParallelColliding, but with the full 65536 queries in flight at once.
TEST_F(TransportTest, ParallelColliding_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Exactly 65536 queries should still be possible in parallel,
    // even if they all have the same original ID.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    while (pending.size() < FakeSocketDelay::sDelay) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const DnsTlsTransport::Result result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
300 
// sDelay distinct queries are in flight together.  Each response must come
// back to the future of the query that produced it.
TEST_F(TransportTest, ParallelUnique) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    const size_t count = FakeSocketDelay::sDelay;
    // Sized up front and never resized, so the slices handed to query() stay valid.
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> futures;
    futures.reserve(count);
    for (size_t q = 0; q < count; ++q) {
        queries[q] = make_query(q, SIZE);
        futures.push_back(transport.query(makeSlice(queries[q])));
    }

    for (size_t q = 0; q < count; ++q) {
        const DnsTlsTransport::Result result = futures[q].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[q], result.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
320 
// Like ParallelUnique, but with the full 65536 distinct queries in flight.
TEST_F(TransportTest, ParallelUnique_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    const size_t count = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> futures;
    // Exactly 65536 queries should still be possible in parallel,
    // and they should all be mapped correctly back to the original ID.
    futures.reserve(count);
    for (size_t q = 0; q < count; ++q) {
        queries[q] = make_query(q, SIZE);
        futures.push_back(transport.query(makeSlice(queries[q])));
    }

    for (size_t q = 0; q < count; ++q) {
        const DnsTlsTransport::Result result = futures[q].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[q], result.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
342 
// Once 65536 queries are outstanding, the next query fails with an internal
// error, while the outstanding ones stay pending.
TEST_F(TransportTest, IdExhaustion) {
    constexpr int kMaxOutstanding = 65536;
    // A delay of 65537 is unreachable, because the maximum number
    // of outstanding queries is 65536.
    FakeSocketDelay::sDelay = kMaxOutstanding + 1;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Issue the maximum number of queries.
    std::vector<std::future<DnsTlsTransport::Result>> outstanding;
    outstanding.reserve(kMaxOutstanding);
    for (int sent = 0; sent < kMaxOutstanding; ++sent) {
        outstanding.push_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so subsequent queries should fail immediately.
    const DnsTlsTransport::Result rejected = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, rejected.code);
    EXPECT_TRUE(rejected.response.empty());

    // All other queries should remain outstanding.
    const auto noWait = std::chrono::duration<int>::zero();
    for (auto& future : outstanding) {
        EXPECT_EQ(std::future_status::timeout, future.wait_for(noWait));
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
370 
371 // Responses can come back from the server in any order.  This should have no
372 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)373 TEST_F(TransportTest, ReverseOrder) {
374     FakeSocketDelay::sDelay = 10;
375     FakeSocketDelay::sReverse = true;
376     FakeSocketFactory<FakeSocketDelay> factory;
377     DnsTlsTransport transport(SERVER1, MARK, &factory);
378     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
379     std::vector<std::future<DnsTlsTransport::Result>> results;
380     results.reserve(FakeSocketDelay::sDelay);
381     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
382         queries[i] = make_query(i, SIZE);
383         results.push_back(transport.query(makeSlice(queries[i])));
384     }
385     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
386         auto r = results[i].get();
387         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
388         EXPECT_EQ(queries[i], r.response);
389     }
390     EXPECT_EQ(transport.getConnectCounter(), 1);
391 }
392 
// Like ReverseOrder, but with the full 65536 distinct queries answered in
// reverse order.
TEST_F(TransportTest, ReverseOrder_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    const size_t count = FakeSocketDelay::sDelay;
    std::vector<bytevec> queries(count);
    std::vector<std::future<DnsTlsTransport::Result>> futures;
    futures.reserve(count);
    for (size_t q = 0; q < count; ++q) {
        queries[q] = make_query(q, SIZE);
        futures.push_back(transport.query(makeSlice(queries[q])));
    }

    for (size_t q = 0; q < count; ++q) {
        const DnsTlsTransport::Result result = futures[q].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[q], result.response);
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
412 
413 // Returning null from the factory indicates a connection failure.
414 class NullSocketFactory : public IDnsTlsSocketFactory {
415   public:
NullSocketFactory()416     NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)417     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
418             const DnsTlsServer& server ATTRIBUTE_UNUSED,
419             unsigned mark ATTRIBUTE_UNUSED,
420             IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
421             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
422         return nullptr;
423     }
424 };
425 
// When the factory cannot create a socket, the query fails with a network
// error and an empty response, and the connect counter reads 1.
TEST_F(TransportTest, ConnectFail) {
    NullSocketFactory nullFactory;
    DnsTlsTransport transport(SERVER1, MARK, &nullFactory);
    const DnsTlsTransport::Result result = transport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
435 
436 // Simulate a socket that connects but then immediately receives a server
437 // close notification.
438 class FakeSocketClose : public IDnsTlsSocket {
439   public:
FakeSocketClose(IDnsTlsSocketObserver * observer)440     explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
441         : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()442     ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)443     bool query(uint16_t id ATTRIBUTE_UNUSED,
444                const Slice query ATTRIBUTE_UNUSED) override {
445         return true;
446     }
447 
448   private:
449     std::thread mCloser;
450 };
451 
// A server that always closes right away makes the query fail with a network
// error after the transport has used up its retries.
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> closingFactory;
    DnsTlsTransport transport(SERVER1, MARK, &closingFactory);
    const DnsTlsTransport::Result result = transport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());

    // Reconnections are triggered since DnsTlsQueryMap is not empty.
    EXPECT_EQ(transport.getConnectCounter(), DnsTlsQueryMap::kMaxTries);
}
463 
// Simulate a server that occasionally closes the connection and silently
// drops some queries.
//
// Per socket: the first sLimit queries are accepted (query() returns true);
// later ones are rejected (query() returns false).  An accepted query is
// echoed back from its own response thread only if its size is <= sMaxSize;
// otherwise it is accepted but never answered.  The sLimit-th query also
// starts a closer thread that first joins all response threads and then
// calls onClosed() on the observer.
class FakeSocketLimited : public IDnsTlsSocket {
  public:
    static int sLimit;  // Number of queries to answer per socket.
    static size_t sMaxSize;  // Silently discard queries greater than this size.
    explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
        : mObserver(observer), mQueries(0) {}
    // Joins any response threads not already joined by sendClose(), then the
    // closer thread (if one was started).
    ~FakeSocketLimited() {
        {
            LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "~FakeSocketLimited joining response thread";
                thread.join();
                LOG(DEBUG) << "~FakeSocketLimited joined response thread";
            }
            mThreads.clear();
        }

        // NOTE(review): mCloser is read here without mLock; presumably safe because
        // query() (its only writer) cannot run concurrently with destruction — confirm.
        if (mCloser) {
            LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
            mCloser->join();
            LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
        }
    }
    // Counts every call (including rejected ones); see the class comment for
    // which queries are answered, dropped or rejected.
    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
        std::lock_guard guard(mLock);
        LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
        ++mQueries;

        if (mQueries <= sLimit) {
            LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
            if (query.size() <= sMaxSize) {
                // Return the response immediately (asynchronously).
                mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
            }
        }
        // Exactly one closer thread per socket: mQueries only increases, so this
        // equality holds on a single call.
        if (mQueries == sLimit) {
            mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
        }
        return mQueries <= sLimit;
    }

  private:
    // Runs on the closer thread: waits for every response started so far to be
    // delivered, then reports the close to the observer outside mLock.
    void sendClose() {
        {
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
                thread.join();
                LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
            }
            mThreads.clear();
        }
        mObserver->onClosed();
    }
    std::mutex mLock;
    IDnsTlsSocketObserver* const mObserver;
    int mQueries GUARDED_BY(mLock);  // Number of query() calls on this socket.
    std::vector<std::thread> mThreads GUARDED_BY(mLock);  // Pending response threads.
    std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);  // Runs sendClose(); null until sLimit.
};
531 
532 int FakeSocketLimited::sLimit;
533 size_t FakeSocketLimited::sMaxSize;
534 
// A server that never answers and closes after sLimit queries makes every
// query fail with a network error once the retries are used up.
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 10 queries.  They will all be ignored, and after the 10th,
    // the socket will close.  Transport will retry them all, until they
    // all hit the retry limit and expire.
    const int count = FakeSocketLimited::sLimit;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(count);
    for (int sent = 0; sent < count; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& future : pending) {
        const DnsTlsTransport::Result result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
        EXPECT_TRUE(result.response.empty());
    }

    // Reconnections are triggered since DnsTlsQueryMap is not empty.
    EXPECT_EQ(transport.getConnectCounter(), DnsTlsQueryMap::kMaxTries);
}
558 
// Short queries are answered and long ones are silently dropped, while the
// server keeps closing after every sLimit queries.  The short queries must
// still all succeed.
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Queue up 100 queries, alternating "short" which will be served and "long"
    // which will be dropped.
    const int total = 10 * FakeSocketLimited::sLimit;
    // Sized up front and never resized, so the slices handed to query() stay valid.
    std::vector<bytevec> queries(total);
    std::vector<std::future<DnsTlsTransport::Result>> futures;
    futures.reserve(total);
    for (int q = 0; q < total; ++q) {
        const bool isLong = (q % 2) != 0;
        queries[q] = make_query(q, SIZE + (isLong ? 1 : 0));
        futures.push_back(transport.query(makeSlice(queries[q])));
    }

    // Just check the short queries, which are at the even indices.
    for (int q = 0; q < total; q += 2) {
        const DnsTlsTransport::Result result = futures[q].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(queries[q], result.response);
    }

    // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
    // connect attempts for this test.
}
585 
// Connections are made on demand.  With the server closing after every
// sLimit queries, the counter ends at num_queries / sLimit.
TEST_F(TransportTest, ConnectCounter) {
    FakeSocketLimited::sLimit = 2;       // Close the socket after 2 queries.
    FakeSocketLimited::sMaxSize = SIZE;  // No query drops.
    FakeSocketFactory<FakeSocketLimited> factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);

    // Connecting on demand.
    EXPECT_EQ(transport.getConnectCounter(), 0);

    const int num_queries = 10;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(num_queries);
    // Reconnections take place every two queries.
    for (int sent = 0; sent < num_queries; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& future : pending) {
        const DnsTlsTransport::Result result = future.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    }

    EXPECT_EQ(transport.getConnectCounter(), num_queries / FakeSocketLimited::sLimit);
}
609 
610 // Simulate a malfunctioning server that injects extra miscellaneous
611 // responses to queries that were not asked.  This will cause wrong answers but
612 // must not crash the Transport.
613 class FakeSocketGarbage : public IDnsTlsSocket {
614   public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)615     explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
616         // Inject a garbage event.
617         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
618     }
~FakeSocketGarbage()619     ~FakeSocketGarbage() {
620         std::lock_guard guard(mLock);
621         for (auto& thread : mThreads) {
622             thread.join();
623         }
624     }
query(uint16_t id,const Slice query)625     bool query(uint16_t id, const Slice query) override {
626         std::lock_guard guard(mLock);
627         // Return the response twice.
628         auto echo = make_echo(id, query);
629         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
630         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
631         // Also return some other garbage
632         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
633         return true;
634     }
635 
636   private:
637     std::mutex mLock;
638     std::vector<std::thread> mThreads GUARDED_BY(mLock);
639     IDnsTlsSocketObserver* const mObserver;
640 };
641 
// Duplicate and unsolicited responses must not break the transport: every
// query still reports success over a single connection.
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> garbageFactory;
    DnsTlsTransport transport(SERVER1, MARK, &garbageFactory);
    for (int attempt = 0; attempt < 10; ++attempt) {
        const DnsTlsTransport::Result result = transport.query(makeSlice(QUERY)).get();

        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        // Don't check the response because this server is malfunctioning.
    }
    EXPECT_EQ(transport.getConnectCounter(), 1);
}
653 
654 // Dispatcher tests
655 class DispatcherTest : public BaseTest {};
656 
// The first query through the dispatcher connects and echoes the query; a
// second query to the same server and mark reuses that connection.
TEST_F(DispatcherTest, Query) {
    bytevec answer(4096);
    int answerLen = 0;
    bool connected = false;

    auto echoFactory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
    DnsTlsDispatcher dispatcher(std::move(echoFactory));
    auto status = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(answer),
                                   &answerLen, &connected);

    EXPECT_EQ(DnsTlsTransport::Response::success, status);
    EXPECT_EQ(int(QUERY.size()), answerLen);
    EXPECT_TRUE(connected);
    answer.resize(answerLen);
    EXPECT_EQ(QUERY, answer);

    // Expect to reuse the connection.
    status = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(answer), &answerLen,
                              &connected);
    EXPECT_EQ(DnsTlsTransport::Response::success, status);
    EXPECT_FALSE(connected);
}
679 
// An answer buffer one byte smaller than the echoed response yields
// limit_error (the connection is still made).
TEST_F(DispatcherTest, AnswerTooLarge) {
    bytevec answer(SIZE - 1);  // Too small to hold the answer
    int answerLen = 0;
    bool connected = false;

    auto echoFactory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
    DnsTlsDispatcher dispatcher(std::move(echoFactory));
    const auto status = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), makeSlice(answer),
                                         &answerLen, &connected);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, status);
    EXPECT_TRUE(connected);
}
693 
694 template<class T>
695 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
696   public:
TrackingFakeSocketFactory()697     TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)698     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
699             const DnsTlsServer& server,
700             unsigned mark,
701             IDnsTlsSocketObserver* observer,
702             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
703         std::lock_guard guard(mLock);
704         keys.emplace(mark, server);
705         return std::make_unique<T>(observer);
706     }
707     std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
708 
709   private:
710     std::mutex mLock;
711 };
712 
// Concurrent queries spread over four (mark, server) keys all succeed, and the
// dispatcher creates exactly one socket per key.
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    auto* tracker = factory.get();  // Valid as long as dispatcher is in scope.
    DnsTlsDispatcher dispatcher(std::move(factory));

    // Populate a vector of two servers and two socket marks, four combinations
    // in total.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // Do several queries on each server, each from its own thread.  They should all succeed.
    const size_t totalQueries = FakeSocketDelay::sDelay * keys.size();
    std::vector<std::thread> workers;
    for (size_t i = 0; i < totalQueries; ++i) {
        const std::pair<unsigned, DnsTlsServer> key = keys[i % keys.size()];
        workers.emplace_back([key, i, &dispatcher] {
            const bytevec q = make_query(i, SIZE);
            bytevec answer(4096);
            int answerLen = 0;
            bool connected = false;
            const unsigned mark = key.first;
            const DnsTlsServer& server = key.second;
            const auto status = dispatcher.query(server, mark, makeSlice(q), makeSlice(answer),
                                                 &answerLen, &connected);
            EXPECT_EQ(DnsTlsTransport::Response::success, status);
            EXPECT_EQ(int(q.size()), answerLen);
            answer.resize(answerLen);
            EXPECT_EQ(q, answer);
        });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), tracker->keys.size());
    for (const auto& key : keys) {
        EXPECT_EQ(1U, tracker->keys.count(key));
    }
}
756 
757 // Check DnsTlsServer's comparison logic.
758 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)759 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
760     bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
761     bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
762     EXPECT_FALSE(cmp1 && cmp2);
763     return !cmp1 && !cmp2;
764 }
765 
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)766 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
767     EXPECT_TRUE(s1 == s1);
768     EXPECT_TRUE(s2 == s2);
769     EXPECT_TRUE(isAddressEqual(s1, s1));
770     EXPECT_TRUE(isAddressEqual(s2, s2));
771 
772     EXPECT_TRUE(s1 < s2 ^ s2 < s1);
773     EXPECT_FALSE(s1 == s2);
774     EXPECT_FALSE(s2 == s1);
775 }
776 
777 class ServerTest : public BaseTest {};
778 
TEST_F(ServerTest,IPv4)779 TEST_F(ServerTest, IPv4) {
780     checkUnequal(V4ADDR1, V4ADDR2);
781     EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
782 }
783 
TEST_F(ServerTest,IPv6)784 TEST_F(ServerTest, IPv6) {
785     checkUnequal(V6ADDR1, V6ADDR2);
786     EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
787 }
788 
TEST_F(ServerTest,MixedAddressFamily)789 TEST_F(ServerTest, MixedAddressFamily) {
790     checkUnequal(V6ADDR1, V4ADDR1);
791     EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
792 }
793 
TEST_F(ServerTest,IPv6ScopeId)794 TEST_F(ServerTest, IPv6ScopeId) {
795     DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
796     sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
797     addr1->sin6_scope_id = 1;
798     sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
799     addr2->sin6_scope_id = 2;
800     checkUnequal(s1, s2);
801     EXPECT_FALSE(isAddressEqual(s1, s2));
802 
803     EXPECT_FALSE(s1.wasExplicitlyConfigured());
804     EXPECT_FALSE(s2.wasExplicitlyConfigured());
805 }
806 
TEST_F(ServerTest,IPv6FlowInfo)807 TEST_F(ServerTest, IPv6FlowInfo) {
808     DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
809     sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
810     addr1->sin6_flowinfo = 1;
811     sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
812     addr2->sin6_flowinfo = 2;
813     // All comparisons ignore flowinfo.
814     EXPECT_EQ(s1, s2);
815     EXPECT_TRUE(isAddressEqual(s1, s2));
816 
817     EXPECT_FALSE(s1.wasExplicitlyConfigured());
818     EXPECT_FALSE(s2.wasExplicitlyConfigured());
819 }
820 
TEST_F(ServerTest,Port)821 TEST_F(ServerTest, Port) {
822     DnsTlsServer s1, s2;
823     parseServer("192.0.2.1", 853, &s1.ss);
824     parseServer("192.0.2.1", 854, &s2.ss);
825     checkUnequal(s1, s2);
826     EXPECT_TRUE(isAddressEqual(s1, s2));
827 
828     DnsTlsServer s3, s4;
829     parseServer("2001:db8::1", 853, &s3.ss);
830     parseServer("2001:db8::1", 852, &s4.ss);
831     checkUnequal(s3, s4);
832     EXPECT_TRUE(isAddressEqual(s3, s4));
833 
834     EXPECT_FALSE(s1.wasExplicitlyConfigured());
835     EXPECT_FALSE(s2.wasExplicitlyConfigured());
836 }
837 
TEST_F(ServerTest,Name)838 TEST_F(ServerTest, Name) {
839     DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
840     s1.name = SERVERNAME1;
841     checkUnequal(s1, s2);
842     s2.name = SERVERNAME2;
843     checkUnequal(s1, s2);
844     EXPECT_TRUE(isAddressEqual(s1, s2));
845 
846     EXPECT_TRUE(s1.wasExplicitlyConfigured());
847     EXPECT_TRUE(s2.wasExplicitlyConfigured());
848 }
849 
TEST_F(ServerTest,Timeout)850 TEST_F(ServerTest, Timeout) {
851     DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
852     s1.connectTimeout = std::chrono::milliseconds(4000);
853     checkUnequal(s1, s2);
854     s2.connectTimeout = std::chrono::milliseconds(4000);
855     EXPECT_EQ(s1, s2);
856     EXPECT_TRUE(isAddressEqual(s1, s2));
857 
858     EXPECT_FALSE(s1.wasExplicitlyConfigured());
859     EXPECT_FALSE(s2.wasExplicitlyConfigured());
860 }
861 
TEST(QueryMapTest,Basic)862 TEST(QueryMapTest, Basic) {
863     DnsTlsQueryMap map;
864 
865     EXPECT_TRUE(map.empty());
866 
867     bytevec q0 = make_query(999, SIZE);
868     bytevec q1 = make_query(888, SIZE);
869     bytevec q2 = make_query(777, SIZE);
870 
871     auto f0 = map.recordQuery(makeSlice(q0));
872     auto f1 = map.recordQuery(makeSlice(q1));
873     auto f2 = map.recordQuery(makeSlice(q2));
874 
875     // Check return values of recordQuery
876     EXPECT_EQ(0, f0->query.newId);
877     EXPECT_EQ(1, f1->query.newId);
878     EXPECT_EQ(2, f2->query.newId);
879 
880     // Check side effects of recordQuery
881     EXPECT_FALSE(map.empty());
882 
883     auto all = map.getAll();
884     EXPECT_EQ(3U, all.size());
885 
886     EXPECT_EQ(0, all[0].newId);
887     EXPECT_EQ(1, all[1].newId);
888     EXPECT_EQ(2, all[2].newId);
889 
890     EXPECT_EQ(q0, all[0].query);
891     EXPECT_EQ(q1, all[1].query);
892     EXPECT_EQ(q2, all[2].query);
893 
894     bytevec a0 = make_query(0, SIZE);
895     bytevec a1 = make_query(1, SIZE);
896     bytevec a2 = make_query(2, SIZE);
897 
898     // Return responses out of order
899     map.onResponse(a2);
900     map.onResponse(a0);
901     map.onResponse(a1);
902 
903     EXPECT_TRUE(map.empty());
904 
905     auto r0 = f0->result.get();
906     auto r1 = f1->result.get();
907     auto r2 = f2->result.get();
908 
909     EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
910     EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
911     EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
912 
913     const bytevec& d0 = r0.response;
914     const bytevec& d1 = r1.response;
915     const bytevec& d2 = r2.response;
916 
917     // The ID should match the query
918     EXPECT_EQ(999, d0[0] << 8 | d0[1]);
919     EXPECT_EQ(888, d1[0] << 8 | d1[1]);
920     EXPECT_EQ(777, d2[0] << 8 | d2[1]);
921     // The body should match the answer
922     EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
923     EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
924     EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
925 }
926 
TEST(QueryMapTest,FillHole)927 TEST(QueryMapTest, FillHole) {
928     DnsTlsQueryMap map;
929     std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
930     for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
931         futures[i] = map.recordQuery(makeSlice(QUERY));
932         ASSERT_TRUE(futures[i]);  // answers[i] should be nonnull.
933         EXPECT_EQ(i, futures[i]->query.newId);
934     }
935 
936     // The map should now be full.
937     EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
938 
939     // Trying to add another query should fail because the map is full.
940     EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
941 
942     // Send an answer to query 40000
943     auto answer = make_query(40000, SIZE);
944     map.onResponse(answer);
945     auto result = futures[40000]->result.get();
946     EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
947     EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
948     EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
949               bytevec(result.response.begin() + 2, result.response.end()));
950 
951     // There should now be room in the map.
952     EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
953     auto f = map.recordQuery(makeSlice(QUERY));
954     ASSERT_TRUE(f);
955     EXPECT_EQ(40000, f->query.newId);
956 
957     // The map should now be full again.
958     EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
959     EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
960 }
961 
962 class StubObserver : public IDnsTlsSocketObserver {
963   public:
964     bool closed = false;
onResponse(std::vector<uint8_t>)965     void onResponse(std::vector<uint8_t>) override {}
966 
onClosed()967     void onClosed() override { closed = true; }
968 };
969 
TEST(DnsTlsSocketTest,SlowDestructor)970 TEST(DnsTlsSocketTest, SlowDestructor) {
971     constexpr char tls_addr[] = "127.0.0.3";
972     constexpr char tls_port[] = "8530";  // High-numbered port so root isn't required.
973     // This test doesn't perform any queries, so the backend address can be invalid.
974     constexpr char backend_addr[] = "192.0.2.1";
975     constexpr char backend_port[] = "1";
976 
977     test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
978     ASSERT_TRUE(tls.startServer());
979 
980     DnsTlsServer server;
981     parseServer(tls_addr, 8530, &server.ss);
982 
983     StubObserver observer;
984     ASSERT_FALSE(observer.closed);
985     DnsTlsSessionCache cache;
986     auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
987     ASSERT_TRUE(socket->initialize());
988 
989     // Test: Time the socket destructor.  This should be fast.
990     auto before = std::chrono::steady_clock::now();
991     socket.reset();
992     auto after = std::chrono::steady_clock::now();
993     auto delay = after - before;
994     LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
995     EXPECT_TRUE(observer.closed);
996     // Shutdown should complete in milliseconds, but if the shutdown signal is lost
997     // it will wait for the timeout, which is expected to take 20seconds.
998     EXPECT_LT(delay, std::chrono::seconds{5});
999 }
1000 
1001 } // end of namespace net
1002 } // end of namespace android
1003