1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NETUTILS_OPERATIONLIMITER_H
18 #define NETUTILS_OPERATIONLIMITER_H
19 
20 #include <mutex>
21 #include <unordered_map>
22 
23 #include <android-base/logging.h>
24 #include <android-base/thread_annotations.h>
25 
26 #include "Experiments.h"
27 
28 namespace android {
29 namespace netdutils {
30 
// Limits the number of outstanding DNS queries by client UID.
// `inline` (C++17) gives the header-defined constant a single definition shared
// by every translation unit, so odr-uses from inline/template code are safe.
inline constexpr int MAX_QUERIES_PER_UID = 256;
// Limits the total number of outstanding DNS queries, across all UIDs combined.
inline constexpr int MAX_QUERIES_IN_TOTAL = 2500;
35 
36 // Tracks the number of operations in progress on behalf of a particular key or
37 // ID, rejecting further attempts to start new operations after a configurable
38 // limit has been reached.
39 //
40 // The intended usage pattern is:
41 //     OperationLimiter<UserId> connections_per_user;
42 //     ...
43 //     int connectToSomeResource(int user) {
44 //         if (!connections_per_user.start(user)) return TRY_AGAIN_LATER;
45 //         // ...do expensive work here...
46 //         connections_per_user.finish(user);
47 //     }
48 //
49 // This class is thread-safe.
50 template <typename KeyType>
51 class OperationLimiter {
52   public:
OperationLimiter(int limitPerKey)53     OperationLimiter(int limitPerKey) : mLimitPerKey(limitPerKey) {}
54 
~OperationLimiter()55     ~OperationLimiter() {
56         DCHECK(mCounters.empty()) << "Destroying OperationLimiter with active operations";
57     }
58 
59     // Returns false if |key| has reached the maximum number of concurrent operations,
60     // or if the global limit has been reached. Otherwise, increments the counter and returns true.
61     //
62     // Note: each successful start(key) must be matched by exactly one call to
63     // finish(key).
EXCLUDES(mMutex)64     bool start(KeyType key, int globalLimit = MAX_QUERIES_IN_TOTAL) EXCLUDES(mMutex) {
65         std::lock_guard lock(mMutex);
66         if (globalLimit < mLimitPerKey) {
67             LOG(ERROR) << "Misconfiguration on max_queries_global " << globalLimit;
68             globalLimit = MAX_QUERIES_IN_TOTAL;
69         }
70         if (mGlobalCounter >= globalLimit) {
71             // Oh, no!
72             LOG(ERROR) << "Query from " << key << " denied due to global limit: " << globalLimit;
73             return false;
74         }
75 
76         auto& cnt = mCounters[key];  // operator[] creates new entries as needed.
77         if (cnt >= mLimitPerKey) {
78             // Oh, no!
79             LOG(ERROR) << "Query from " << key << " denied due to limit: " << mLimitPerKey;
80             return false;
81         }
82 
83         ++cnt;
84         ++mGlobalCounter;
85         return true;
86     }
87 
88     // Decrements the number of operations in progress accounted to |key|.
89     // See usage notes on start().
finish(KeyType key)90     void finish(KeyType key) EXCLUDES(mMutex) {
91         std::lock_guard lock(mMutex);
92 
93         --mGlobalCounter;
94         if (mGlobalCounter < 0) {
95             LOG(FATAL_WITHOUT_ABORT) << "Global operations counter going negative, this is a bug.";
96             return;
97         }
98 
99         auto it = mCounters.find(key);
100         if (it == mCounters.end()) {
101             LOG(FATAL_WITHOUT_ABORT) << "Decremented non-existent counter for key=" << key;
102             return;
103         }
104         auto& cnt = it->second;
105         --cnt;
106         if (cnt <= 0) {
107             // Cleanup counters once they drop down to zero.
108             mCounters.erase(it);
109         }
110     }
111 
112   private:
113     // Protects access to the map below.
114     std::mutex mMutex;
115 
116     // Tracks the number of outstanding queries by key.
117     std::unordered_map<KeyType, int> mCounters GUARDED_BY(mMutex);
118 
119     int mGlobalCounter GUARDED_BY(mMutex) = 0;
120 
121     // Maximum number of outstanding queries from a single key.
122     const int mLimitPerKey;
123 };
124 
125 }  // namespace netdutils
126 }  // namespace android
127 
128 #endif  // NETUTILS_OPERATIONLIMITER_H
129