1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef _DNS_DNSTLSDISPATCHER_H 18 #define _DNS_DNSTLSDISPATCHER_H 19 20 #include <list> 21 #include <map> 22 #include <memory> 23 #include <mutex> 24 25 #include <android-base/thread_annotations.h> 26 #include <netdutils/Slice.h> 27 28 #include "DnsTlsServer.h" 29 #include "DnsTlsTransport.h" 30 #include "IDnsTlsSocketFactory.h" 31 #include "PrivateDnsValidationObserver.h" 32 #include "resolv_private.h" 33 34 namespace android { 35 namespace net { 36 37 // This is a singleton class that manages the collection of active DnsTlsTransports. 38 // Queries made here are dispatched to an existing or newly constructed DnsTlsTransport. 39 // TODO: PrivateDnsValidationObserver is not implemented in this class. Remove it. 40 class DnsTlsDispatcher : public PrivateDnsValidationObserver { 41 public: 42 // Constructor with dependency injection for testing. 43 explicit DnsTlsDispatcher(std::unique_ptr<IDnsTlsSocketFactory> factory) 44 : mFactory(std::move(factory)) {} 45 46 static DnsTlsDispatcher& getInstance(); 47 48 // Enqueues |query| for resolution via the given |tlsServers| on the 49 // network indicated by |mark|; writes the response into |ans|, and stores 50 // the count of bytes written in |resplen|. Returns a success or error code. 51 // The order in which servers from |tlsServers| are queried may not be the 52 // order passed in by the caller. 53 DnsTlsTransport::Response query(const std::list<DnsTlsServer>& tlsServers, 54 res_state _Nonnull statp, const netdutils::Slice query, 55 const netdutils::Slice ans, int* _Nonnull resplen); 56 57 // Given a |query|, sends it to the server on the network indicated by |mark|, 58 // and writes the response into |ans|, and indicates the number of bytes written in |resplen|. 59 // If the whole procedure above triggers (or experiences) any new connection, |connectTriggered| 60 // is set. Returns a success or error code. 61 DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned netId, unsigned mark, 62 const netdutils::Slice query, const netdutils::Slice ans, 63 int* _Nonnull resplen, bool* _Nonnull connectTriggered); 64 65 // Implement PrivateDnsValidationObserver. 66 void onValidationStateUpdate(const std::string&, Validation, uint32_t) override{}; 67 68 void forceCleanup(unsigned netId) EXCLUDES(sLock); 69 70 private: 71 DnsTlsDispatcher(); 72 73 // This lock is static so that it can be used to annotate the Transport struct. 74 // DnsTlsDispatcher is a singleton in practice, so making this static does not change 75 // the locking behavior. 76 static std::mutex sLock; 77 78 // Key = <mark, server> 79 typedef std::pair<unsigned, const DnsTlsServer> Key; 80 81 // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and 82 // usage monitoring so we can expire idle sessions from the cache. 83 struct Transport { 84 Transport(const DnsTlsServer& server, unsigned mark, unsigned netId, 85 IDnsTlsSocketFactory* _Nonnull factory, bool revalidationEnabled, int triggerThr, 86 int unusableThr, int timeout) 87 : transport(server, mark, factory), 88 mNetId(netId), 89 revalidationEnabled(revalidationEnabled), 90 triggerThreshold(triggerThr), 91 unusableThreshold(unusableThr), 92 mTimeout(timeout) {} 93 94 // DnsTlsTransport is thread-safe, so it doesn't need to be guarded. 95 DnsTlsTransport transport; 96 97 // The expected network, assigned from dns_netid, to which Transport will send DNS packets. 98 const unsigned mNetId; 99 100 // This use counter and timestamp are used to ensure that only idle sessions are 101 // destroyed. 102 int useCount GUARDED_BY(sLock) = 0; 103 // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero. 104 std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock); 105 106 // If DoT revalidation is disabled, it returns true; otherwise, it returns 107 // whether or not this Transport is usable. 108 bool usable() const REQUIRES(sLock); 109 110 bool checkRevalidationNecessary(DnsTlsTransport::Response code) REQUIRES(sLock); 111 112 std::chrono::milliseconds timeout() const { return mTimeout; } 113 114 static constexpr int kDotRevalidationThreshold = -1; 115 static constexpr int kDotXportUnusableThreshold = -1; 116 static constexpr int kDotQueryTimeoutMs = -1; 117 118 private: 119 // Used to track if this Transport is usable. 120 int continuousfailureCount GUARDED_BY(sLock) = 0; 121 122 // Used to indicate whether DoT revalidation is enabled for this Transport. 123 // The value is set to true only if: 124 // 1. both triggerThreshold and unusableThreshold are positive values. 125 // 2. private DNS mode is opportunistic. 126 const bool revalidationEnabled; 127 128 // The number of continuous failures to trigger a validation. It takes effect when DoT 129 // revalidation is on. If the value is not a positive value, DoT revalidation is disabled. 130 // Note that it must be at least 10, or it breaks ConnectTlsServerTimeout_ConcurrentQueries 131 // test. 132 const int triggerThreshold; 133 134 // The threshold to determine if this Transport is considered unusable. 135 // If continuousfailureCount reaches this value, this Transport is no longer used. It 136 // takes effect when DoT revalidation is on. If the value is not a positive value, DoT 137 // revalidation is disabled. 138 const int unusableThreshold; 139 140 // The time to await a future (the result of a DNS request) from the DnsTlsTransport 141 // of this Transport. 142 // To set an infinite timeout, assign the value to -1. 143 const std::chrono::milliseconds mTimeout; 144 }; 145 146 Transport* _Nullable addTransport(const DnsTlsServer& server, unsigned mark, unsigned netId) 147 REQUIRES(sLock); 148 Transport* _Nullable getTransport(const Key& key) REQUIRES(sLock); 149 150 // Cache of reusable DnsTlsTransports. Transports stay in cache as long as 151 // they are in use and for a few minutes after. 152 std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock); 153 154 // The last time we did a cleanup. For efficiency, we only perform a cleanup once every 155 // few minutes. 156 std::chrono::time_point<std::chrono::steady_clock> mLastCleanup GUARDED_BY(sLock); 157 158 DnsTlsTransport::Result queryInternal(Transport& transport, const netdutils::Slice query) 159 EXCLUDES(sLock); 160 161 // Drop any cache entries whose useCount is zero and which have not been used recently. 162 // This function performs a linear scan of mStore. 163 void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock); 164 165 // Force dropping any Transports whose useCount is zero. 166 void forceCleanupLocked(unsigned netId) REQUIRES(sLock); 167 168 // Return a sorted list of usable DnsTlsServers in preference order. 169 std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers, 170 unsigned netId, unsigned mark); 171 172 // Trivial factory for DnsTlsSockets. Dependency injection is only used for testing. 173 std::unique_ptr<IDnsTlsSocketFactory> mFactory; 174 }; 175 176 } // end of namespace net 177 } // end of namespace android 178 179 #endif // _DNS_DNSTLSDISPATCHER_H 180