1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils.async;
18 
19 import com.android.net.module.util.async.OsAccess;
20 
21 import java.util.Arrays;
22 
23 /**
24  * Limits the number of bytes processed to the given maximum of bytes per second.
25  *
26  * The limiter tracks the total for the past second, along with sums for each 10ms
27  * in the past second, allowing the total to be adjusted as the time passes.
28  */
29 public final class RateLimiter {
30     private static final int PERIOD_DURATION_MS = 1000;
31     private static final int BUCKET_COUNT = 100;
32 
33     public static final int BUCKET_DURATION_MS = PERIOD_DURATION_MS / BUCKET_COUNT;
34 
35     private final OsAccess mOsAccess;
36     private final int[] mStatBuckets = new int[BUCKET_COUNT];
37     private int mMaxPerPeriodBytes;
38     private int mMaxPerBucketBytes;
39     private int mRecordedPeriodBytes;
40     private long mLastLimitTimestamp;
41     private int mLastRequestReduction;
42 
RateLimiter(OsAccess osAccess, int bytesPerSecond)43     public RateLimiter(OsAccess osAccess, int bytesPerSecond) {
44         mOsAccess = osAccess;
45         setBytesPerSecond(bytesPerSecond);
46         clear();
47     }
48 
getBytesPerSecond()49     public int getBytesPerSecond() {
50         return mMaxPerPeriodBytes;
51     }
52 
setBytesPerSecond(int bytesPerSecond)53     public void setBytesPerSecond(int bytesPerSecond) {
54         mMaxPerPeriodBytes = bytesPerSecond;
55         mMaxPerBucketBytes = Math.max(1, (mMaxPerPeriodBytes / BUCKET_COUNT) * 2);
56     }
57 
clear()58     public void clear() {
59         mLastLimitTimestamp = mOsAccess.monotonicTimeMillis();
60         mRecordedPeriodBytes = 0;
61         Arrays.fill(mStatBuckets, 0);
62     }
63 
limit(RateLimiter limiter1, RateLimiter limiter2, int requestedBytes)64     public static int limit(RateLimiter limiter1, RateLimiter limiter2, int requestedBytes) {
65         final long now = limiter1.mOsAccess.monotonicTimeMillis();
66         final int allowedCount = Math.min(limiter1.calculateLimit(now, requestedBytes),
67             limiter2.calculateLimit(now, requestedBytes));
68         limiter1.recordBytes(now, requestedBytes, allowedCount);
69         limiter2.recordBytes(now, requestedBytes, allowedCount);
70         return allowedCount;
71     }
72 
limit(int requestedBytes)73     public int limit(int requestedBytes) {
74         final long now = mOsAccess.monotonicTimeMillis();
75         final int allowedCount = calculateLimit(now, requestedBytes);
76         recordBytes(now, requestedBytes, allowedCount);
77         return allowedCount;
78     }
79 
getLastRequestReduction()80     public int getLastRequestReduction() {
81         return mLastRequestReduction;
82     }
83 
acceptAllOrNone(int requestedBytes)84     public boolean acceptAllOrNone(int requestedBytes) {
85         final long now = mOsAccess.monotonicTimeMillis();
86         final int allowedCount = calculateLimit(now, requestedBytes);
87         if (allowedCount < requestedBytes) {
88             return false;
89         }
90         recordBytes(now, requestedBytes, allowedCount);
91         return true;
92     }
93 
calculateLimit(long now, int requestedBytes)94     private int calculateLimit(long now, int requestedBytes) {
95         // First remove all stale bucket data and adjust the total.
96         final long currentBucketAbsIdx = now / BUCKET_DURATION_MS;
97         final long staleCutoffIdx = currentBucketAbsIdx - BUCKET_COUNT;
98         for (long i = mLastLimitTimestamp / BUCKET_DURATION_MS; i < staleCutoffIdx; i++) {
99             final int idx = (int) (i % BUCKET_COUNT);
100             mRecordedPeriodBytes -= mStatBuckets[idx];
101             mStatBuckets[idx] = 0;
102         }
103 
104         final int bucketIdx = (int) (currentBucketAbsIdx % BUCKET_COUNT);
105         final int maxAllowed = Math.min(mMaxPerPeriodBytes - mRecordedPeriodBytes,
106             Math.min(mMaxPerBucketBytes - mStatBuckets[bucketIdx], requestedBytes));
107         return Math.max(0, maxAllowed);
108     }
109 
recordBytes(long now, int requestedBytes, int actualBytes)110     private void recordBytes(long now, int requestedBytes, int actualBytes) {
111         mStatBuckets[(int) ((now / BUCKET_DURATION_MS) % BUCKET_COUNT)] += actualBytes;
112         mRecordedPeriodBytes += actualBytes;
113         mLastRequestReduction = requestedBytes - actualBytes;
114         mLastLimitTimestamp = now;
115     }
116 
117     @Override
toString()118     public String toString() {
119         StringBuilder sb = new StringBuilder();
120         sb.append("{max=");
121         sb.append(mMaxPerPeriodBytes);
122         sb.append(",max_bucket=");
123         sb.append(mMaxPerBucketBytes);
124         sb.append(",total=");
125         sb.append(mRecordedPeriodBytes);
126         sb.append(",last_red=");
127         sb.append(mLastRequestReduction);
128         sb.append('}');
129         return sb.toString();
130     }
131 }
132