1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.net
18 
19 import android.net.DataUsageRequest
20 import android.net.netstats.IUsageCallback
21 import android.os.IBinder
22 import java.util.concurrent.LinkedBlockingQueue
23 import java.util.concurrent.TimeUnit
24 import kotlin.test.fail
25 
/** How long, in milliseconds, the expect/assert helpers below wait for a callback by default. */
private const val DEFAULT_TIMEOUT_MS: Long = 200
27 
// TODO: Move the class to static libs once all downstream have IUsageCallback definition.
/**
 * Test double for [IUsageCallback] that records every callback it receives, in arrival order,
 * so that tests can assert on them one at a time.
 *
 * Received callbacks are stored in a [LinkedBlockingQueue] (a thread-safe queue), and the
 * `expect*` / [assertNoCallback] helpers block on it for up to a timeout.
 *
 * @param binder the [IBinder] returned from [asBinder] in place of this stub.
 */
class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() {
    /** One recorded callback invocation, carrying the [DataUsageRequest] it was invoked with. */
    sealed class CallbackType(val request: DataUsageRequest) {
        class OnThresholdReached(request: DataUsageRequest) : CallbackType(request)
        class OnCallbackReleased(request: DataUsageRequest) : CallbackType(request)

        // Readable form for failure messages; the default Object.toString() only shows a hash.
        override fun toString() = "${javaClass.simpleName}(request=$request)"
    }

    // TODO: Change to use ArrayTrackRecord once moved into the module.
    private val history = LinkedBlockingQueue<CallbackType>()

    override fun onThresholdReached(request: DataUsageRequest) {
        history.add(CallbackType.OnThresholdReached(request))
    }

    override fun onCallbackReleased(request: DataUsageRequest) {
        history.add(CallbackType.OnCallbackReleased(request))
    }

    /**
     * Consumes the next recorded callback and fails unless it is [CallbackType.OnThresholdReached]
     * for [request].
     *
     * @param timeoutMs how long to wait for the callback, in milliseconds.
     */
    @JvmOverloads
    fun expectOnThresholdReached(request: DataUsageRequest, timeoutMs: Long = DEFAULT_TIMEOUT_MS) {
        expectCallback<CallbackType.OnThresholdReached>(request, timeoutMs)
    }

    /**
     * Consumes the next recorded callback and fails unless it is [CallbackType.OnCallbackReleased]
     * for [request].
     *
     * @param timeoutMs how long to wait for the callback, in milliseconds.
     */
    @JvmOverloads
    fun expectOnCallbackReleased(request: DataUsageRequest, timeoutMs: Long = DEFAULT_TIMEOUT_MS) {
        expectCallback<CallbackType.OnCallbackReleased>(request, timeoutMs)
    }

    /**
     * Fails if any callback is recorded within [timeout] milliseconds.
     *
     * If a callback does arrive, it is removed from the history before failing.
     */
    @JvmOverloads
    fun assertNoCallback(timeout: Long = DEFAULT_TIMEOUT_MS) {
        val cb = history.poll(timeout, TimeUnit.MILLISECONDS)
        if (cb != null) fail("Expected no callback but got $cb")
    }

    /**
     * Consumes the next recorded callback, waiting up to [timeoutMs] milliseconds, and returns it.
     *
     * Fails if no callback arrives in time, if the callback is not of type [T], or if it carries a
     * request other than [expectedRequest]. In each case the consumed callback is not put back.
     */
    private inline fun <reified T : CallbackType> expectCallback(
        expectedRequest: DataUsageRequest,
        timeoutMs: Long,
    ): T {
        val expectation = "${T::class.simpleName} with Request[$expectedRequest]"
        // poll() returns null on timeout; report that distinctly instead of "Unexpected callback : null".
        val cb = history.poll(timeoutMs, TimeUnit.MILLISECONDS)
            ?: fail("Expected $expectation but got no callback within ${timeoutMs}ms")
        if (cb is T && cb.request == expectedRequest) return cb
        fail("Unexpected callback : $cb, expected $expectation")
    }

    /** Returns the binder supplied at construction rather than this stub itself. */
    override fun asBinder(): IBinder {
        return binder
    }
}