1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef _DNS_DNSTLSSOCKET_H 18 #define _DNS_DNSTLSSOCKET_H 19 20 #include <openssl/ssl.h> 21 #include <future> 22 #include <mutex> 23 24 #include <android-base/thread_annotations.h> 25 #include <android-base/unique_fd.h> 26 #include <netdutils/Slice.h> 27 #include <netdutils/Status.h> 28 29 #include "DnsTlsServer.h" 30 #include "IDnsTlsSocket.h" 31 #include "LockedQueue.h" 32 33 namespace android { 34 namespace net { 35 36 class IDnsTlsSocketObserver; 37 class DnsTlsSessionCache; 38 39 // A class for managing a TLS socket that sends and receives messages in 40 // [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format). 41 // This class is not aware of query-response pairing or anything else about DNS. 42 // For the observer: 43 // This class is not re-entrant: the observer is not permitted to wait for a call to query() 44 // or the destructor in a callback. Doing so will result in deadlocks. 45 // This class may call the observer at any time after initialize(), until the destructor 46 // returns (but not after). 47 // 48 // Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle: 49 // 50 // UNINITIALIZED 51 // | 52 // v 53 // INITIALIZED 54 // | 55 // v 56 // +----CONNECTING------+ 57 // Handshake fails | | Handshake succeeds 58 // (onClose() when | | 59 // mAsyncHandshake is set) | v 60 // | +---> CONNECTED --+ 61 // | | | | 62 // | +-----------+ | Idle timeout 63 // | Send/Recv queries | onClose() 64 // | onResponse() | 65 // | | 66 // | | 67 // +--> WAIT_FOR_DELETE <-----+ 68 // 69 // 70 // TODO: Add onHandshakeFinished() for handshake results. 71 class DnsTlsSocket : public IDnsTlsSocket { 72 public: 73 enum class State { 74 UNINITIALIZED, 75 INITIALIZED, 76 CONNECTING, 77 CONNECTED, 78 WAIT_FOR_DELETE, 79 }; 80 81 DnsTlsSocket(const DnsTlsServer& server, unsigned mark, 82 IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache) 83 : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} 84 ~DnsTlsSocket(); 85 86 // Creates the SSL context for this session. Returns false on failure. 87 // This method should be called after construction and before use of a DnsTlsSocket. 88 // Only call this method once per DnsTlsSocket. 89 bool initialize() EXCLUDES(mLock); 90 91 // If async handshake is enabled, this function simply signals a handshake request, and the 92 // handshake will be performed in the loop thread; otherwise, if async handshake is disabled, 93 // this function performs the handshake and returns after the handshake finishes. 94 bool startHandshake() EXCLUDES(mLock); 95 96 // Send a query on the provided SSL socket. |query| contains 97 // the body of a query, not including the ID header. This function will typically return before 98 // the query is actually sent. If this function fails, DnsTlsSocketObserver will be 99 // notified that the socket is closed. 100 // Note that success here indicates successful sending, not receipt of a response. 101 // Thread-safe. 102 bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock); 103 104 private: 105 // Lock to be held by the SSL event loop thread. This is not normally in contention. 106 std::mutex mLock; 107 108 // Forwards queries and receives responses. Blocks until the idle timeout. 109 void loop() EXCLUDES(mLock); 110 std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock); 111 112 // On success, sets mSslFd to a socket connected to mAddr (the 113 // connection will likely be in progress if mProtocol is IPPROTO_TCP). 114 // On error, returns the errno. 115 netdutils::Status tcpConnect() REQUIRES(mLock); 116 117 bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock); 118 119 // Connect an SSL session on the provided socket. If connection fails, closing the 120 // socket remains the caller's responsibility. 121 bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock); 122 123 // Connect an SSL session on the provided socket. This is an interruptible version 124 // which allows to terminate connection handshake any time. 125 bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock); 126 127 // Disconnect the SSL session and close the socket. 128 void sslDisconnect() REQUIRES(mLock); 129 130 // Writes a buffer to the socket. 131 bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock); 132 133 // Reads exactly the specified number of bytes from the socket, or fails. 134 // Returns SSL_ERROR_NONE on success. 135 // If |wait| is true, then this function always blocks. Otherwise, it 136 // will return SSL_ERROR_WANT_READ if there is no data from the server to read. 137 int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock); 138 139 bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); 140 141 // Read one DNS response. It can potentially block until reading the exact bytes of 142 // the response. 143 bool readResponse() REQUIRES(mLock); 144 145 // It is only used for DNS-OVER-TLS internal test. 146 bool setTestCaCertificate() REQUIRES(mLock); 147 148 // Similar to query(), this function uses incrementEventFd to send a message to the 149 // loop thread. However, instead of incrementing the counter by one (indicating a 150 // new query), it wraps the counter to negative, which we use to indicate a shutdown 151 // request. 152 void requestLoopShutdown() EXCLUDES(mLock); 153 154 // This function sends a message to the loop thread by incrementing mEventFd. 155 bool incrementEventFd(int64_t count) EXCLUDES(mLock); 156 157 // Transition the state from expected state |from| to new state |to|. 158 void transitionState(State from, State to) REQUIRES(mLock); 159 160 // Queue of pending queries. query() pushes items onto the queue and notifies 161 // the loop thread by incrementing mEventFd. loop() reads items off the queue. 162 LockedQueue<std::vector<uint8_t>> mQueue; 163 164 // eventfd socket used for notifying the SSL thread when queries are ready to send. 165 // This socket acts similarly to an atomic counter, incremented by query() and cleared 166 // by loop(). We have to use a socket because the SSL thread needs to wait in poll() 167 // for input from either a remote server or a query thread. Since eventfd does not have 168 // EOF, we indicate a close request by setting the counter to a negative number. 169 // This file descriptor is opened by initialize(), and closed implicitly after 170 // destruction. 171 // Note that: data starts being read from the eventfd when the state is CONNECTED. 172 base::unique_fd mEventFd; 173 174 // An eventfd used to listen to shutdown requests when the state is CONNECTING. 175 // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively 176 // handle shutdown requests. 177 base::unique_fd mShutdownEvent; 178 179 // SSL Socket fields. 180 bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); 181 base::unique_fd mSslFd GUARDED_BY(mLock); 182 bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock); 183 static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20); 184 185 const unsigned mMark; // Socket mark 186 const DnsTlsServer mServer; 187 IDnsTlsSocketObserver* _Nonnull const mObserver; 188 DnsTlsSessionCache* _Nonnull const mCache; 189 State mState GUARDED_BY(mLock) = State::UNINITIALIZED; 190 191 // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's 192 // thread (the call to startHandshake()). 193 bool mAsyncHandshake GUARDED_BY(mLock) = false; 194 195 // The time to wait for the attempt on connecting to the server. 196 // Set the default value 127 seconds to be consistent with TCP connect timeout. 197 // (presume net.ipv4.tcp_syn_retries = 6) 198 static constexpr int kDotConnectTimeoutMs = 127 * 1000; 199 int mConnectTimeoutMs; 200 201 // For testing. 202 friend class DnsTlsSocketTest; 203 }; 204 205 } // end of namespace net 206 } // end of namespace android 207 208 #endif // _DNS_DNSTLSSOCKET_H 209