1 /*
 * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net
18 
19 import android.app.Instrumentation
20 import android.content.Context
21 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
22 import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
23 import android.net.NetworkCapabilities.TRANSPORT_TEST
24 import android.net.NetworkProviderTest.TestNetworkCallback.CallbackEntry.OnUnavailable
25 import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequestWithdrawn
26 import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequested
27 import android.os.Build
28 import android.os.Handler
29 import android.os.HandlerThread
30 import android.os.Looper
31 import android.util.Log
32 import androidx.test.InstrumentationRegistry
33 import com.android.modules.utils.build.SdkLevel.isAtLeastS
34 import com.android.net.module.util.ArrayTrackRecord
35 import com.android.testutils.CompatUtil
36 import com.android.testutils.ConnectivityModuleTest
37 import com.android.testutils.DevSdkIgnoreRule
38 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
39 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
40 import com.android.testutils.DevSdkIgnoreRunner
41 import com.android.testutils.TestableNetworkOfferCallback
42 import java.util.UUID
43 import java.util.concurrent.Executor
44 import java.util.concurrent.RejectedExecutionException
45 import kotlin.test.assertEquals
46 import kotlin.test.assertNotEquals
47 import kotlin.test.fail
48 import org.junit.After
49 import org.junit.Before
50 import org.junit.Rule
51 import org.junit.Test
52 import org.junit.runner.RunWith
53 import org.mockito.Mockito.doReturn
54 import org.mockito.Mockito.mock
55 import org.mockito.Mockito.verifyNoMoreInteractions
56 
/** How long to wait for an expected callback before failing the test. */
private const val DEFAULT_TIMEOUT_MS = 5000L
/** How long to wait while asserting that no callback arrives. */
private const val DEFAULT_NO_CALLBACK_TIMEOUT_MS = 200L
/** The instrumentation running this test; re-read from [InstrumentationRegistry] on each access. */
private val instrumentation: Instrumentation
    get() = InstrumentationRegistry.getInstrumentation()
/** Context obtained from [InstrumentationRegistry]; re-read on each access. */
private val context: Context get() = InstrumentationRegistry.getContext()
/** Name under which every [NetworkProviderTest] test provider is created. */
private const val PROVIDER_NAME = "NetworkProviderTest"
63 
@RunWith(DevSdkIgnoreRunner::class)
@ConnectivityModuleTest
class NetworkProviderTest {
    @Rule @JvmField
    val mIgnoreRule = DevSdkIgnoreRule()
    private val mCm = context.getSystemService(ConnectivityManager::class.java)!!
    private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")

    /** Adopts the shell permission identity and starts the handler thread used by providers. */
    @Before
    fun setUp() {
        instrumentation.uiAutomation.adoptShellPermissionIdentity()
        mHandlerThread.start()
    }

    /**
     * Stops the handler thread, then drops the shell permission identity.
     *
     * The identity is dropped in a `finally` block so it is released even if quitting or
     * joining the handler thread throws.
     */
    @After
    fun tearDown() {
        try {
            mHandlerThread.quitSafely()
            mHandlerThread.join()
        } finally {
            instrumentation.uiAutomation.dropShellPermissionIdentity()
        }
    }

    /** [NetworkProvider] that records the request callbacks it receives for later assertions. */
    private class TestNetworkProvider(context: Context, looper: Looper) :
            NetworkProvider(context, looper, PROVIDER_NAME) {
        private val TAG = this::class.simpleName
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()

        /** Callbacks recorded by [TestNetworkProvider]. */
        sealed class CallbackEntry {
            data class OnNetworkRequested(
                val request: NetworkRequest,
                val score: Int,
                val id: Int
            ) : CallbackEntry()
            data class OnNetworkRequestWithdrawn(val request: NetworkRequest) : CallbackEntry()
        }

        override fun onNetworkRequested(request: NetworkRequest, score: Int, id: Int) {
            Log.d(TAG, "onNetworkRequested $request, $score, $id")
            seenEvents.add(OnNetworkRequested(request, score, id))
        }

        override fun onNetworkRequestWithdrawn(request: NetworkRequest) {
            Log.d(TAG, "onNetworkRequestWithdrawn $request")
            seenEvents.add(OnNetworkRequestWithdrawn(request))
        }

        /**
         * Polls the recorded callbacks for one of type [T] that satisfies [predicate] and
         * returns it.
         *
         * Fails the test if no such callback is polled within [DEFAULT_TIMEOUT_MS].
         */
        inline fun <reified T : CallbackEntry> eventuallyExpectCallbackThat(
            crossinline predicate: (T) -> Boolean
        ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
                ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")

        /** Fails the test if any callback is polled within [DEFAULT_NO_CALLBACK_TIMEOUT_MS]. */
        fun assertNoCallback() {
            val cb = seenEvents.poll(DEFAULT_NO_CALLBACK_TIMEOUT_MS)
            if (null != cb) fail("Expected no callback but got $cb")
        }
    }

    /** Creates an unregistered [TestNetworkProvider] bound to [mHandlerThread]'s looper. */
    private fun createNetworkProvider(ctx: Context = context): TestNetworkProvider {
        return TestNetworkProvider(ctx, mHandlerThread.looper)
    }

    /**
     * Creates a [TestNetworkProvider] and registers it with [ConnectivityManager].
     *
     * Asserts that the provider has no ID before registration and has one afterwards.
     */
    private fun createAndRegisterNetworkProvider(ctx: Context = context) =
        createNetworkProvider(ctx).also {
            // kotlin.test assertions take (expected, actual).
            assertEquals(NetworkProvider.ID_NONE, it.providerId)
            mCm.registerNetworkProvider(it)
            assertNotEquals(NetworkProvider.ID_NONE, it.providerId)
        }

    // Only runs up to R. From S on, a provider no longer receives onNetworkRequested for every
    // request. Instead it calls `registerNetworkOffer` with a description of the networks it
    // might be able to set up, and is notified through `NetworkOfferCallback.onNetworkNeeded`.
    @IgnoreAfter(Build.VERSION_CODES.R)
    @Test
    fun testOnNetworkRequested() {
        val provider = createAndRegisterNetworkProvider()

        val specifier = CompatUtil.makeTestNetworkSpecifier(
                UUID.randomUUID().toString())
        // Test network is not allowed to be trusted.
        val nr: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier)
                .build()
        val cb = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nr, cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.request.hasTransport(TRANSPORT_TEST)
        }

        val initialScore = 40
        val updatedScore = 60
        val nc = NetworkCapabilities().apply {
                addTransportType(NetworkCapabilities.TRANSPORT_TEST)
                removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
                removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
                addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
                addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
                setNetworkSpecifier(specifier)
        }
        val lp = LinkProperties()
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc, lp,
                initialScore, config, provider) {}
        agent.register()
        agent.markConnected()

        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.score == initialScore &&
            callback.id == agent.providerId
        }

        agent.sendNetworkScore(updatedScore)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.score == updatedScore &&
            callback.id == agent.providerId
        }

        mCm.unregisterNetworkCallback(cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequestWithdrawn>() { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.request.hasTransport(TRANSPORT_TEST)
        }
        // Tear down the test network so that it does not outlive this test.
        agent.unregister()
        mCm.unregisterNetworkProvider(provider)
        // Provider id should be ID_NONE after unregister network provider
        assertEquals(NetworkProvider.ID_NONE, provider.providerId)
        // unregisterNetworkProvider should not crash even if it's called on an
        // already unregistered provider.
        mCm.unregisterNetworkProvider(provider)
    }

    // Mainline modules can't use the internal HandlerExecutor, so an identical executor lives here.
    // TODO: Refactor with the one in MultiNetworkPolicyTracker.
    /** [Executor] that posts each command to [handler] and rejects it if the post fails. */
    private class HandlerExecutor(private val handler: Handler) : Executor {
        override fun execute(command: Runnable) {
            if (!handler.post(command)) {
                throw RejectedExecutionException("$handler is shutting down")
            }
        }
    }

    @IgnoreUpTo(Build.VERSION_CODES.R)
    @Test
    fun testRegisterNetworkOffer() {
        val provider = createAndRegisterNetworkProvider()
        val provider2 = createAndRegisterNetworkProvider()

        // Prepare the materials which will be used to create different offers.
        val specifier1 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-1")
        val specifier2 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-2")
        val scoreWeaker = NetworkScore.Builder().build()
        val scoreStronger = NetworkScore.Builder().setTransportPrimary(true).build()
        val ncFilter1 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter2 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter3 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()
        val ncFilter4 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()

        // Make 4 offers: 1 doesn't have NOT_VCN, 2 has NOT_VCN, 3 is similar to 1 but with a
        // different specifier, and 4 has the same filter as 3 but comes from another provider.
        val offerCallback1 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback2 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback3 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback4 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        provider.registerNetworkOffer(scoreWeaker, ncFilter1,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback1)
        provider.registerNetworkOffer(scoreStronger, ncFilter2,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback2)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback3)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback4)
        // Unlike Android R, Android S+ provider will only receive interested requests via offer
        // callback. Verify that the callbacks do not see any existing request such as default
        // requests.
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request with specifier but without NOT_VCN, verify network is needed for callback
        // with the same specifier.
        val nrNoNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier1)
                .build()
        val cb1 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNoNotVcn, cb1)
        offerCallback1.expectOnNetworkNeeded(ncFilter1)
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        mCm.unregisterNetworkCallback(cb1)
        offerCallback1.expectOnNetworkUnneeded(ncFilter1)
        offerCallback2.expectOnNetworkUnneeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request without specifier but with NOT_VCN, verify network is needed for offer
        // with NOT_VCN.
        val nrNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .build()
        val cb2 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNotVcn, cb2)
        offerCallback1.assertNoCallback()
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // Upgrade offer 3 & 4 to satisfy previous request and then verify they are also needed.
        ncFilter3.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback3)
        ncFilter4.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback4)
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.expectOnNetworkNeeded(ncFilter3)
        offerCallback4.expectOnNetworkNeeded(ncFilter4)

        // Connect an agent to fulfill the request, verify offer 4 is not needed since it is not
        // from currently serving provider nor can beat the current satisfier.
        val nc = NetworkCapabilities().apply {
            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
            setNetworkSpecifier(specifier1)
        }
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc,
                LinkProperties(), scoreWeaker, config, provider) {}
        agent.register()
        agent.markConnected()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.expectOnNetworkUnneeded(ncFilter4)

        // Upgrade the agent, verify no change since the framework will treat the offer as needed
        // if a request is currently satisfied by the network provided by the same provider.
        // TODO: Consider offers with weaker score are unneeded.
        agent.sendNetworkScore(scoreStronger)
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.assertNoCallback()  // Still unneeded.

        // Verify that offer callbacks cannot receive any event if offer is unregistered.
        provider2.unregisterNetworkOffer(offerCallback4)
        agent.unregister()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        // Since the agent is unregistered, and the offer has chance to satisfy the request,
        // this callback should receive needed if it is not unregistered.
        offerCallback4.assertNoCallback()

        // Verify that offer callbacks cannot receive any event if provider is unregistered.
        mCm.unregisterNetworkProvider(provider)
        mCm.unregisterNetworkCallback(cb2)
        offerCallback1.assertNoCallback()  // No callback since it is still unneeded.
        offerCallback2.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback3.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback4.assertNoCallback()  // Already unregistered.

        // Clean up and verify providers did not receive any callback during the entire test.
        mCm.unregisterNetworkProvider(provider2)
        provider.assertNoCallback()
        provider2.assertNoCallback()
    }

    /** [ConnectivityManager.NetworkCallback] that records `onUnavailable` events. */
    private class TestNetworkCallback : ConnectivityManager.NetworkCallback() {
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()

        /** Callbacks recorded by [TestNetworkCallback]. */
        sealed class CallbackEntry {
            object OnUnavailable : CallbackEntry()
        }

        override fun onUnavailable() {
            seenEvents.add(OnUnavailable)
        }

        /**
         * Polls the recorded callbacks for one of type [T] that satisfies [predicate].
         *
         * Unlike `TestNetworkProvider.eventuallyExpectCallbackThat`, this does NOT fail on
         * timeout: it returns null when nothing matching is polled within [DEFAULT_TIMEOUT_MS],
         * so callers must check the result themselves.
         */
        inline fun <reified T : CallbackEntry> expectCallback(
            crossinline predicate: (T) -> Boolean
        ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
    }

    @Test
    fun testDeclareNetworkRequestUnfulfillable() {
        val mockContext = mock(Context::class.java)
        doReturn(mCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
        val provider = createNetworkProvider(mockContext)
        // ConnectivityManager not required at creation time after R
        if (isAtLeastS()) {
            verifyNoMoreInteractions(mockContext)
        }

        mCm.registerNetworkProvider(provider)

        val specifier = CompatUtil.makeTestNetworkSpecifier(
                UUID.randomUUID().toString())
        val nr: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier)
                .build()

        val cb = TestNetworkCallback()
        mCm.requestNetwork(nr, cb)
        provider.declareNetworkRequestUnfulfillable(nr)
        // NOTE(review): the result of expectCallback is not checked, and the predicate ignores
        // its argument, so this line cannot fail even if onUnavailable never arrives. `nr` is
        // the locally built request and may not compare equal to the request that
        // requestNetwork registered. Confirm the framework delivers onUnavailable here before
        // making this assertion strict (e.g. `?: fail(...)`).
        cb.expectCallback<OnUnavailable>() { nr.getNetworkSpecifier() == specifier }
        mCm.unregisterNetworkProvider(provider)
    }
}
399