1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef _DNS_DNSTLSSOCKET_H
18 #define _DNS_DNSTLSSOCKET_H
19 
#include <openssl/ssl.h>

#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
23 
24 #include <android-base/thread_annotations.h>
25 #include <android-base/unique_fd.h>
26 #include <netdutils/Slice.h>
27 #include <netdutils/Status.h>
28 
29 #include "DnsTlsServer.h"
30 #include "IDnsTlsSocket.h"
31 #include "LockedQueue.h"
32 
33 namespace android {
34 namespace net {
35 
36 class IDnsTlsSocketObserver;
37 class DnsTlsSessionCache;
38 
39 // A class for managing a TLS socket that sends and receives messages in
40 // [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
41 // This class is not aware of query-response pairing or anything else about DNS.
42 // For the observer:
43 // This class is not re-entrant: the observer is not permitted to wait for a call to query()
44 // or the destructor in a callback.  Doing so will result in deadlocks.
45 // This class may call the observer at any time after initialize(), until the destructor
46 // returns (but not after).
47 //
48 // Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
49 //
50 //                                UNINITIALIZED
51 //                                      |
52 //                                      v
53 //                                 INITIALIZED
54 //                                      |
55 //                                      v
56 //                            +----CONNECTING------+
57 //            Handshake fails |                    | Handshake succeeds
58 //   (onClose() when          |                    |
59 //    mAsyncHandshake is set) |                    v
60 //                            |        +---> CONNECTED --+
61 //                            |        |           |     |
62 //                            |        +-----------+     | Idle timeout
63 //                            |   Send/Recv queries      | onClose()
64 //                            |   onResponse()           |
65 //                            |                          |
66 //                            |                          |
67 //                            +--> WAIT_FOR_DELETE <-----+
68 //
69 //
70 // TODO: Add onHandshakeFinished() for handshake results.
71 class DnsTlsSocket : public IDnsTlsSocket {
72   public:
73     enum class State {
74         UNINITIALIZED,
75         INITIALIZED,
76         CONNECTING,
77         CONNECTED,
78         WAIT_FOR_DELETE,
79     };
80 
81     DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
82                  IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
83         : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
84     ~DnsTlsSocket();
85 
86     // Creates the SSL context for this session. Returns false on failure.
87     // This method should be called after construction and before use of a DnsTlsSocket.
88     // Only call this method once per DnsTlsSocket.
89     bool initialize() EXCLUDES(mLock);
90 
91     // If async handshake is enabled, this function simply signals a handshake request, and the
92     // handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
93     // this function performs the handshake and returns after the handshake finishes.
94     bool startHandshake() EXCLUDES(mLock);
95 
96     // Send a query on the provided SSL socket.  |query| contains
97     // the body of a query, not including the ID header. This function will typically return before
98     // the query is actually sent.  If this function fails, DnsTlsSocketObserver will be
99     // notified that the socket is closed.
100     // Note that success here indicates successful sending, not receipt of a response.
101     // Thread-safe.
102     bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
103 
104   private:
105     // Lock to be held by the SSL event loop thread.  This is not normally in contention.
106     std::mutex mLock;
107 
108     // Forwards queries and receives responses.  Blocks until the idle timeout.
109     void loop() EXCLUDES(mLock);
110     std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
111 
112     // On success, sets mSslFd to a socket connected to mAddr (the
113     // connection will likely be in progress if mProtocol is IPPROTO_TCP).
114     // On error, returns the errno.
115     netdutils::Status tcpConnect() REQUIRES(mLock);
116 
117     bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);
118 
119     // Connect an SSL session on the provided socket.  If connection fails, closing the
120     // socket remains the caller's responsibility.
121     bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
122 
123     // Connect an SSL session on the provided socket. This is an interruptible version
124     // which allows to terminate connection handshake any time.
125     bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);
126 
127     // Disconnect the SSL session and close the socket.
128     void sslDisconnect() REQUIRES(mLock);
129 
130     // Writes a buffer to the socket.
131     bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
132 
133     // Reads exactly the specified number of bytes from the socket, or fails.
134     // Returns SSL_ERROR_NONE on success.
135     // If |wait| is true, then this function always blocks.  Otherwise, it
136     // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
137     int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
138 
139     bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
140 
141     // Read one DNS response. It can potentially block until reading the exact bytes of
142     // the response.
143     bool readResponse() REQUIRES(mLock);
144 
145     // It is only used for DNS-OVER-TLS internal test.
146     bool setTestCaCertificate() REQUIRES(mLock);
147 
148     // Similar to query(), this function uses incrementEventFd to send a message to the
149     // loop thread.  However, instead of incrementing the counter by one (indicating a
150     // new query), it wraps the counter to negative, which we use to indicate a shutdown
151     // request.
152     void requestLoopShutdown() EXCLUDES(mLock);
153 
154     // This function sends a message to the loop thread by incrementing mEventFd.
155     bool incrementEventFd(int64_t count) EXCLUDES(mLock);
156 
157     // Transition the state from expected state |from| to new state |to|.
158     void transitionState(State from, State to) REQUIRES(mLock);
159 
160     // Queue of pending queries.  query() pushes items onto the queue and notifies
161     // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
162     LockedQueue<std::vector<uint8_t>> mQueue;
163 
164     // eventfd socket used for notifying the SSL thread when queries are ready to send.
165     // This socket acts similarly to an atomic counter, incremented by query() and cleared
166     // by loop().  We have to use a socket because the SSL thread needs to wait in poll()
167     // for input from either a remote server or a query thread.  Since eventfd does not have
168     // EOF, we indicate a close request by setting the counter to a negative number.
169     // This file descriptor is opened by initialize(), and closed implicitly after
170     // destruction.
171     // Note that: data starts being read from the eventfd when the state is CONNECTED.
172     base::unique_fd mEventFd;
173 
174     // An eventfd used to listen to shutdown requests when the state is CONNECTING.
175     // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
176     // handle shutdown requests.
177     base::unique_fd mShutdownEvent;
178 
179     // SSL Socket fields.
180     bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
181     base::unique_fd mSslFd GUARDED_BY(mLock);
182     bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
183     static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
184 
185     const unsigned mMark;  // Socket mark
186     const DnsTlsServer mServer;
187     IDnsTlsSocketObserver* _Nonnull const mObserver;
188     DnsTlsSessionCache* _Nonnull const mCache;
189     State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
190 
191     // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
192     // thread (the call to startHandshake()).
193     bool mAsyncHandshake GUARDED_BY(mLock) = false;
194 
195     // The time to wait for the attempt on connecting to the server.
196     // Set the default value 127 seconds to be consistent with TCP connect timeout.
197     // (presume net.ipv4.tcp_syn_retries = 6)
198     static constexpr int kDotConnectTimeoutMs = 127 * 1000;
199     int mConnectTimeoutMs;
200 
201     // For testing.
202     friend class DnsTlsSocketTest;
203 };
204 
205 }  // end of namespace net
206 }  // end of namespace android
207 
208 #endif  // _DNS_DNSTLSSOCKET_H
209