1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
17 #include <arpa/inet.h>
18 #include <fuzzer/FuzzedDataProvider.h>
19 #include <media/stagefright/foundation/AMessage.h>
20 #include <media/stagefright/rtsp/ARTSPConnection.h>
21 #include <thread>
23 using namespace android;
// ---- Authentication ---------------------------------------------------------
// Schemes placed after "www-authenticate: " in the fuzzed 401 response.
const std::string kAuthType[] = {"Basic", "Digest"};
const std::string kNonce = " nonce=\"\"";
const std::string kRealm = " realm=\"\"";

// ---- Small textual building blocks -----------------------------------------
const std::string kTab = "\t";
const std::string kSpace = " ";
const std::string kNewLine = "\n";
const std::string kBinaryHeader = "$";
const std::string kHeaderBoundary = "\r\n\r\n";
const std::string kCSeq = "CSeq: ";
const std::string kContentLength = "content-length: ";
constexpr char kNullValue = '\0';
constexpr char kDefaultValue = '0';

// ---- Requests, URLs and canned responses -----------------------------------
const std::string kDefaultRequestValue = "INVALID_FORMAT";
const std::string kUrlPrefix = "rtsp://root:pass@";
const std::string kRequestMarker = "REQUEST_SENT";
const std::string kQuitResponse = "\n\n\n\n";
// kRTSPVersion must be defined before the two strings built from it below.
const std::string kRTSPVersion = "RTSP/1.0";
const std::string kValidResponse = kRTSPVersion + " 200 \n";
const std::string kAuthString = kRTSPVersion + " 401 \nwww-authenticate: ";

// ---- Numeric parameters ----------------------------------------------------
constexpr int32_t kWhat = 'resp';  // 'what' code of the reply message
constexpr int32_t kMinPort = 100;  // range of the fuzzed status code
constexpr int32_t kMaxPort = 999;
constexpr int32_t kMinASCIIValue = 32;  // printable ASCII range for binary header bytes
constexpr int32_t kMaxASCIIValue = 126;
constexpr int32_t kMinContentLength = 0;
constexpr int32_t kMaxContentLength = 1000;
constexpr int32_t kBinaryVectorSize = 3;  // bytes following the '$' marker
constexpr int32_t kDefaultCseqValue = 1;
constexpr int32_t kBufferSize = 1024;
constexpr int32_t kMaxLoopRuns = 5;
constexpr int32_t kPort = 554;
constexpr int32_t kMaxBytes = 128;
constexpr int32_t kMaxThreads = 1024;
59 struct FuzzAHandler : public AHandler {
60   public:
FuzzAHandlerFuzzAHandler61     FuzzAHandler(std::function<void()> signalEosFunction)
62         : mSignalEosFunction(std::move(signalEosFunction)) {}
63     ~FuzzAHandler() = default;
65   protected:
onMessageReceivedFuzzAHandler66     void onMessageReceived(const sp<AMessage>& msg) override {
67         switch (msg->what()) {
68             case kWhat: {
69                 mSignalEosFunction();
70                 break;
71             }
72         }
73     }
75   private:
76     std::function<void()> mSignalEosFunction;
77 };
79 class RTSPConnectionFuzzer {
80   public:
RTSPConnectionFuzzer(const uint8_t * data,size_t size)81     RTSPConnectionFuzzer(const uint8_t* data, size_t size) : mFdp(data, size){};
~RTSPConnectionFuzzer()82     ~RTSPConnectionFuzzer() {
83         // wait for all the threads to join the main thread
84         for (auto& thread : mThreadPool) {
85             if (thread.joinable()) {
86                 thread.join();
87             }
88         }
89         close(mServerFd);
90     }
91     void process();
93   private:
94     void signalEos();
95     void startServer();
96     void createFuzzData();
97     void acceptConnection();
98     void handleConnection(int32_t);
99     void handleClientResponse(int32_t);
100     void sendValidResponse(int32_t, int32_t);
101     int32_t checkSocket(int32_t);
102     size_t generateBinaryDataSize(std::string);
103     bool checkValidRequestData(const AString&);
104     bool mEosReached = false;
105     bool mServerFailure = false;
106     bool mNotifyResponseListener = false;
107     int32_t mServerFd;
108     std::string mFuzzData = "";
109     std::string mFuzzRequestData = "";
110     std::string mRequestData = kDefaultRequestValue;
111     std::mutex mFuzzDataMutex;
112     std::mutex mMsgPostCompleteMutex;
113     std::condition_variable mConditionalVariable;
114     std::vector<std::thread> mThreadPool;
115     FuzzedDataProvider mFdp;
116 };
generateBinaryDataSize(std::string values)118 size_t RTSPConnectionFuzzer::generateBinaryDataSize(std::string values) {
119     // computed the binary data size as done in ARTSPConnection.cpp
120     uint8_t x = values[0];
121     uint8_t y = values[1];
122     return x << 8 | y;
123 }
checkValidRequestData(const AString & request)125 bool RTSPConnectionFuzzer::checkValidRequestData(const AString& request) {
126     if (request.find(kHeaderBoundary.c_str()) <= 0) {
127         return false;
128     }
129     ssize_t space = request.find(kSpace.c_str());
130     if (space <= 0) {
131         return false;
132     }
133     if (request.find(kSpace.c_str(), space + 1) <= 0) {
134         return false;
135     }
136     return true;
137 }
createFuzzData()139 void RTSPConnectionFuzzer::createFuzzData() {
140     std::unique_lock fuzzLock(mFuzzDataMutex);
141     mFuzzData = "";
142     mFuzzRequestData = "";
143     int32_t contentLength = 0;
144     if (mFdp.ConsumeBool()) {
145         if (mFdp.ConsumeBool()) {
146             // if we want to handle server request
147             mFuzzData.append(kSpace + kSpace + kRTSPVersion);
148         } else {
149             // if we want to notify response listener
150             mFuzzData.append(
151                     kRTSPVersion + kSpace +
152                     std::to_string(mFdp.ConsumeIntegralInRange<uint16_t>(kMinPort, kMaxPort)) +
153                     kSpace);
154         }
155         mFuzzData.append(kNewLine);
156         if (mFdp.ConsumeBool()) {
157             contentLength =
158                     mFdp.ConsumeIntegralInRange<int32_t>(kMinContentLength, kMaxContentLength);
159             mFuzzData.append(kContentLength + std::to_string(contentLength) + kNewLine);
160             if (mFdp.ConsumeBool()) {
161                 mFdp.ConsumeBool() ? mFuzzData.append(kSpace + kNewLine)
162                                    : mFuzzData.append(kTab + kNewLine);
163             }
164         }
165         // new line to break out of infinite for loop
166         mFuzzData.append(kNewLine);
167         if (contentLength) {
168             std::string contentData = mFdp.ConsumeBytesAsString(contentLength);
169             contentData.resize(contentLength, kDefaultValue);
170             mFuzzData.append(contentData);
171         }
172     } else {
173         // for binary data
174         std::string randomValues(kBinaryVectorSize, kNullValue);
175         for (size_t idx = 0; idx < kBinaryVectorSize; ++idx) {
176             randomValues[idx] =
177                     (char)mFdp.ConsumeIntegralInRange<uint8_t>(kMinASCIIValue, kMaxASCIIValue);
178         }
179         size_t binaryDataSize = generateBinaryDataSize(randomValues);
180         std::string data = mFdp.ConsumeBytesAsString(binaryDataSize);
181         data.resize(binaryDataSize, kDefaultValue);
182         mFuzzData.append(kBinaryHeader + randomValues + data);
183     }
184     if (mFdp.ConsumeBool()) {
185         mRequestData = mFdp.ConsumeRandomLengthString(kMaxBytes) + kSpace + kSpace +
186                        kHeaderBoundary + mFdp.ConsumeRandomLengthString(kMaxBytes);
187         // Check if Request data is valid
188         if (checkValidRequestData(mRequestData.c_str())) {
189             if (mFdp.ConsumeBool()) {
190                 if (mFdp.ConsumeBool()) {
191                     // if we want to handle server request
192                     mFuzzRequestData.append(kSpace + kSpace + kRTSPVersion + kNewLine);
193                 } else {
194                     // if we want to add authentication headers
195                     mNotifyResponseListener = true;
196                     mFuzzRequestData.append(kAuthString);
197                     if (mFdp.ConsumeBool()) {
198                         // for Authentication type: Basic
199                         mFuzzRequestData.append(kAuthType[0]);
200                     } else {
201                         // for Authentication type: Digest
202                         mFuzzRequestData.append(kAuthType[1]);
203                         mFuzzRequestData.append(kNonce);
204                         mFuzzRequestData.append(kRealm);
205                     }
206                     mFuzzRequestData.append(kNewLine);
207                 }
208             } else {
209                 mNotifyResponseListener = false;
210                 mFuzzRequestData.append(kValidResponse);
211             }
212         } else {
213             mRequestData = kDefaultRequestValue;
214         }
215     } else {
216         mRequestData = kDefaultRequestValue;
217         mFuzzData.append(kNewLine);
218     }
219 }
signalEos()221 void RTSPConnectionFuzzer::signalEos() {
222     mEosReached = true;
223     mConditionalVariable.notify_all();
224     return;
225 }
checkSocket(int32_t newSocket)227 int32_t RTSPConnectionFuzzer::checkSocket(int32_t newSocket) {
228     struct timeval tv;
229     tv.tv_sec = 1;
230     tv.tv_usec = 0;
232     fd_set rs;
233     FD_ZERO(&rs);
234     FD_SET(newSocket, &rs);
236     return select(newSocket + 1, &rs, nullptr, nullptr, &tv);
237 }
sendValidResponse(int32_t newSocket,int32_t cseq=-1)239 void RTSPConnectionFuzzer::sendValidResponse(int32_t newSocket, int32_t cseq = -1) {
240     std::string validResponse = kValidResponse;
241     if (cseq != -1) {
242         validResponse.append(kCSeq + std::to_string(cseq));
243         validResponse.append(kNewLine + kNewLine);
244     } else {
245         validResponse.append(kNewLine);
246     }
247     send(newSocket, validResponse.c_str(), validResponse.size(), 0);
248 }
handleClientResponse(int32_t newSocket)250 void RTSPConnectionFuzzer::handleClientResponse(int32_t newSocket) {
251     char buffer[kBufferSize] = {0};
252     if (checkSocket(newSocket) == 1) {
253         read(newSocket, buffer, kBufferSize);
254     }
255 }
handleConnection(int32_t newSocket)257 void RTSPConnectionFuzzer::handleConnection(int32_t newSocket) {
258     std::unique_lock fuzzLock(mFuzzDataMutex);
259     send(newSocket, mFuzzData.c_str(), mFuzzData.size(), 0);
260     if (mFuzzData[0] == kSpace[0]) {
261         handleClientResponse(newSocket);
262     }
264     if (mFuzzRequestData != "") {
265         char buffer[kBufferSize] = {0};
266         if (checkSocket(newSocket) == 1 && recv(newSocket, buffer, kBufferSize, MSG_DONTWAIT) > 0) {
267             // Extract the 'CSeq' value present at the end of header
268             std::string clientResponse(buffer);
269             std::string header = clientResponse.substr(0, clientResponse.find(kHeaderBoundary));
270             char cseq = header[header.rfind(kCSeq) + kCSeq.length()];
271             int32_t cseqValue = cseq ? cseq - '0' : kDefaultCseqValue;
272             std::string response = mFuzzRequestData;
273             response.append(kCSeq + std::to_string(cseqValue));
274             response.append(kNewLine + kNewLine);
275             send(newSocket, response.data(), response.length(), 0);
277             if (!mNotifyResponseListener) {
278                 char buffer[kBufferSize] = {0};
279                 if (checkSocket(newSocket) == 1) {
280                     if (recv(newSocket, buffer, kBufferSize, MSG_DONTWAIT) > 0) {
281                         // Extract the 'CSeq' value present at the end of header
282                         std::string clientResponse(buffer);
283                         std::string header =
284                                 clientResponse.substr(0, clientResponse.find(kHeaderBoundary));
285                         char cseq = header[header.rfind(kCSeq) + kCSeq.length()];
286                         int32_t cseqValue = cseq ? cseq - '0' : kDefaultCseqValue;
287                         sendValidResponse(newSocket, cseqValue);
288                     } else {
289                         sendValidResponse(newSocket);
290                     }
291                 }
292             }
293         } else {
294             // If no data to read, then send a valid response
295             // to release the mutex lock in fuzzer
296             sendValidResponse(newSocket);
297         }
298     }
299     send(newSocket, kQuitResponse.c_str(), kQuitResponse.size(), 0);
300 }
startServer()302 void RTSPConnectionFuzzer::startServer() {
303     signal(SIGPIPE, SIG_IGN);
304     mServerFd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0);
305     struct sockaddr_in serverAddress;
306     serverAddress.sin_family = AF_INET;
307     serverAddress.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
308     serverAddress.sin_port = htons(kPort);
310     // Get rid of "Address in use" error
311     int32_t opt = 1;
312     if (setsockopt(mServerFd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) {
313         mServerFailure = true;
314     }
316     // Bind the socket and set for listening.
317     if (bind(mServerFd, (struct sockaddr*)(&serverAddress), sizeof(serverAddress)) < 0) {
318         mServerFailure = true;
319     }
321     if (listen(mServerFd, 5) < 0) {
322         mServerFailure = true;
323     }
324 }
acceptConnection()326 void RTSPConnectionFuzzer::acceptConnection() {
327     int32_t clientFd = accept4(mServerFd, nullptr, nullptr, SOCK_CLOEXEC);
328     handleConnection(clientFd);
329     close(clientFd);
330 }
process()332 void RTSPConnectionFuzzer::process() {
333     startServer();
334     if (mServerFailure) {
335         return;
336     }
337     sp<ALooper> looper = sp<ALooper>::make();
338     sp<FuzzAHandler> handler =
339             sp<FuzzAHandler>::make(std::bind(&RTSPConnectionFuzzer::signalEos, this));
340     sp<ARTSPConnection> rtspConnection =
341             sp<ARTSPConnection>::make(mFdp.ConsumeBool(), mFdp.ConsumeIntegral<uint64_t>());
342     looper->start();
343     looper->registerHandler(rtspConnection);
344     looper->registerHandler(handler);
345     sp<AMessage> replymsg = sp<AMessage>::make(kWhat, handler);
346     std::string url = kUrlPrefix + std::to_string(kPort) + "/";
348     while (mFdp.remaining_bytes() && mThreadPool.size() < kMaxThreads) {
349         createFuzzData();
350         mThreadPool.push_back(std::thread(&RTSPConnectionFuzzer::acceptConnection, this));
351         if (mFdp.ConsumeBool()) {
352             rtspConnection->observeBinaryData(replymsg);
353         }
355         {
356             rtspConnection->connect(url.c_str(), replymsg);
357             std::unique_lock waitForMsgPostComplete(mMsgPostCompleteMutex);
358             mConditionalVariable.wait(waitForMsgPostComplete, [this] {
359                 if (mEosReached == true) {
360                     mEosReached = false;
361                     return true;
362                 }
363                 return mEosReached;
364             });
365         }
367         if (mRequestData != kDefaultRequestValue) {
368             rtspConnection->sendRequest(mRequestData.c_str(), replymsg);
369             std::unique_lock waitForMsgPostComplete(mMsgPostCompleteMutex);
370             mConditionalVariable.wait(waitForMsgPostComplete, [this] {
371                 if (mEosReached == true) {
372                     mEosReached = false;
373                     return true;
374                 }
375                 return mEosReached;
376             });
377         }
379         if (mFdp.ConsumeBool()) {
380             rtspConnection->disconnect(replymsg);
381             std::unique_lock waitForMsgPostComplete(mMsgPostCompleteMutex);
382             mConditionalVariable.wait(waitForMsgPostComplete, [this] {
383                 if (mEosReached == true) {
384                     mEosReached = false;
385                     return true;
386                 }
387                 return mEosReached;
388             });
389         }
390     }
391 }
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)393 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
394     RTSPConnectionFuzzer rtspFuzz(data, size);
395     rtspFuzz.process();
396     return 0;
397 }