1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "PendingRequestPool.h"
18 
19 #include <VehicleHalTypes.h>
20 #include <VehicleUtils.h>
21 
22 #include <utils/Log.h>
23 #include <utils/SystemClock.h>
24 
25 #include <vector>
26 
27 namespace android {
28 namespace hardware {
29 namespace automotive {
30 namespace vehicle {
31 
32 namespace {
33 
34 using ::aidl::android::hardware::automotive::vehicle::StatusCode;
35 using ::android::base::Result;
36 
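// Upper bound on the polling interval: pending requests are checked for timeout at least once
// every second, even when the configured timeout is longer.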
38 constexpr int64_t CHECK_TIME_IN_NANO = 1'000'000'000;
39 
40 }  // namespace
41 
PendingRequestPool(int64_t timeoutInNano)42 PendingRequestPool::PendingRequestPool(int64_t timeoutInNano) : mTimeoutInNano(timeoutInNano) {
43     mThread = std::thread([this] {
44         // [this] must be alive within this thread because destructor would wait for this thread
45         // to exit.
46         int64_t sleepTime = std::min(mTimeoutInNano, static_cast<int64_t>(CHECK_TIME_IN_NANO));
47         std::unique_lock<std::mutex> lk(mCvLock);
48         while (!mCv.wait_for(lk, std::chrono::nanoseconds(sleepTime),
49                              [this] { return mThreadStop; })) {
50             checkTimeout();
51         }
52     });
53 }
54 
~PendingRequestPool()55 PendingRequestPool::~PendingRequestPool() {
56     {
57         // Even if the shared variable is atomic, it must be modified under the
58         // mutex in order to correctly publish the modification to the waiting
59         // thread.
60         std::unique_lock<std::mutex> lk(mCvLock);
61         mThreadStop = true;
62     }
63     mCv.notify_all();
64     if (mThread.joinable()) {
65         mThread.join();
66     }
67 
68     // If this pool is being destructed, send out all pending requests as timeout.
69     {
70         std::scoped_lock<std::mutex> lockGuard(mLock);
71 
72         for (auto& [_, pendingRequests] : mPendingRequestsByClient) {
73             for (const auto& request : pendingRequests) {
74                 (*request.callback)(request.requestIds);
75             }
76         }
77         mPendingRequestsByClient.clear();
78     }
79 }
80 
addRequests(const void * clientId,const std::unordered_set<int64_t> & requestIds,std::shared_ptr<const TimeoutCallbackFunc> callback)81 VhalResult<void> PendingRequestPool::addRequests(
82         const void* clientId, const std::unordered_set<int64_t>& requestIds,
83         std::shared_ptr<const TimeoutCallbackFunc> callback) {
84     std::scoped_lock<std::mutex> lockGuard(mLock);
85     std::list<PendingRequest>* pendingRequests;
86     size_t pendingRequestCount = 0;
87     if (mPendingRequestsByClient.find(clientId) != mPendingRequestsByClient.end()) {
88         pendingRequests = &mPendingRequestsByClient[clientId];
89         for (const auto& pendingRequest : *pendingRequests) {
90             const auto& pendingRequestIds = pendingRequest.requestIds;
91             for (int64_t requestId : requestIds) {
92                 if (pendingRequestIds.find(requestId) != pendingRequestIds.end()) {
93                     return StatusError(StatusCode::INVALID_ARG)
94                            << "duplicate request ID: " << requestId;
95                 }
96             }
97             pendingRequestCount += pendingRequestIds.size();
98         }
99     } else {
100         // Create a new empty list for this client.
101         pendingRequests = &mPendingRequestsByClient[clientId];
102     }
103 
104     if (requestIds.size() > MAX_PENDING_REQUEST_PER_CLIENT - pendingRequestCount) {
105         return StatusError(StatusCode::TRY_AGAIN) << "too many pending requests";
106     }
107 
108     int64_t currentTime = elapsedRealtimeNano();
109     int64_t timeoutTimestamp = currentTime + mTimeoutInNano;
110 
111     pendingRequests->push_back({
112             .requestIds = std::unordered_set<int64_t>(requestIds.begin(), requestIds.end()),
113             .timeoutTimestamp = timeoutTimestamp,
114             .callback = callback,
115     });
116 
117     return {};
118 }
119 
isRequestPending(const void * clientId,int64_t requestId) const120 bool PendingRequestPool::isRequestPending(const void* clientId, int64_t requestId) const {
121     std::scoped_lock<std::mutex> lockGuard(mLock);
122 
123     return isRequestPendingLocked(clientId, requestId);
124 }
125 
countPendingRequests() const126 size_t PendingRequestPool::countPendingRequests() const {
127     std::scoped_lock<std::mutex> lockGuard(mLock);
128 
129     size_t count = 0;
130     for (const auto& [clientId, requests] : mPendingRequestsByClient) {
131         for (const auto& request : requests) {
132             count += request.requestIds.size();
133         }
134     }
135     return count;
136 }
137 
countPendingRequests(const void * clientId) const138 size_t PendingRequestPool::countPendingRequests(const void* clientId) const {
139     std::scoped_lock<std::mutex> lockGuard(mLock);
140 
141     auto it = mPendingRequestsByClient.find(clientId);
142     if (it == mPendingRequestsByClient.end()) {
143         return 0;
144     }
145 
146     size_t count = 0;
147     for (const auto& pendingRequest : it->second) {
148         count += pendingRequest.requestIds.size();
149     }
150 
151     return count;
152 }
153 
isRequestPendingLocked(const void * clientId,int64_t requestId) const154 bool PendingRequestPool::isRequestPendingLocked(const void* clientId, int64_t requestId) const {
155     auto it = mPendingRequestsByClient.find(clientId);
156     if (it == mPendingRequestsByClient.end()) {
157         return false;
158     }
159     for (const auto& pendingRequest : it->second) {
160         const auto& requestIds = pendingRequest.requestIds;
161         if (requestIds.find(requestId) != requestIds.end()) {
162             return true;
163         }
164     }
165     return false;
166 }
167 
checkTimeout()168 void PendingRequestPool::checkTimeout() {
169     std::vector<PendingRequest> timeoutRequests;
170     {
171         std::scoped_lock<std::mutex> lockGuard(mLock);
172 
173         int64_t currentTime = elapsedRealtimeNano();
174 
175         std::vector<const void*> clientsWithEmptyRequests;
176 
177         for (auto& [clientId, pendingRequests] : mPendingRequestsByClient) {
178             auto it = pendingRequests.begin();
179             while (it != pendingRequests.end()) {
180                 if (it->timeoutTimestamp >= currentTime) {
181                     break;
182                 }
183                 timeoutRequests.push_back(std::move(*it));
184                 it = pendingRequests.erase(it);
185             }
186 
187             if (pendingRequests.empty()) {
188                 clientsWithEmptyRequests.push_back(clientId);
189             }
190         }
191 
192         for (const void* clientId : clientsWithEmptyRequests) {
193             mPendingRequestsByClient.erase(clientId);
194         }
195     }
196 
197     // Call the callback outside the lock.
198     for (const auto& request : timeoutRequests) {
199         (*request.callback)(request.requestIds);
200     }
201 }
202 
tryFinishRequests(const void * clientId,const std::unordered_set<int64_t> & requestIds)203 std::unordered_set<int64_t> PendingRequestPool::tryFinishRequests(
204         const void* clientId, const std::unordered_set<int64_t>& requestIds) {
205     std::scoped_lock<std::mutex> lockGuard(mLock);
206 
207     std::unordered_set<int64_t> foundIds;
208 
209     if (mPendingRequestsByClient.find(clientId) == mPendingRequestsByClient.end()) {
210         return foundIds;
211     }
212 
213     auto& pendingRequests = mPendingRequestsByClient[clientId];
214     auto it = pendingRequests.begin();
215     while (it != pendingRequests.end()) {
216         auto& pendingRequestIds = it->requestIds;
217         for (int64_t requestId : requestIds) {
218             auto idIt = pendingRequestIds.find(requestId);
219             if (idIt == pendingRequestIds.end()) {
220                 continue;
221             }
222             pendingRequestIds.erase(idIt);
223             foundIds.insert(requestId);
224         }
225         if (pendingRequestIds.empty()) {
226             it = pendingRequests.erase(it);
227             continue;
228         }
229         it++;
230     }
231 
232     return foundIds;
233 }
234 
235 }  // namespace vehicle
236 }  // namespace automotive
237 }  // namespace hardware
238 }  // namespace android
239