1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.federatedcompute;
17 
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import android.content.ComponentName;
import android.content.Context;
import android.content.ContextWrapper;
import android.content.Intent;
import android.content.ServiceConnection;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.content.pm.ServiceInfo;
import android.federatedcompute.aidl.IFederatedComputeCallback;
import android.federatedcompute.aidl.IFederatedComputeService;
import android.federatedcompute.common.ScheduleFederatedComputeRequest;
import android.federatedcompute.common.TrainingOptions;
import android.os.IBinder;
import android.os.OutcomeReceiver;
import android.os.RemoteException;

import androidx.test.core.app.ApplicationProvider;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
59 
60 @RunWith(Parameterized.class)
61 public class FederatedComputeManagerTest {
62 
63     private final Context mContext =
64             spy(new MyTestContext(ApplicationProvider.getApplicationContext()));
65 
66     private static final ComponentName OWNER_COMPONENT =
67             ComponentName.createRelative("com.android.package.name", "com.android.class.name");
68 
69     @Parameterized.Parameter(0)
70     public String scenario;
71 
72     @Parameterized.Parameter(1)
73     public ScheduleFederatedComputeRequest request;
74 
75     @Parameterized.Parameter(2)
76     public String populationName;
77 
78     @Parameterized.Parameter(3)
79     public IFederatedComputeService iFederatedComputeService;
80 
81     @Mock private PackageManager mMockPackageManager;
82     @Mock private IBinder mMockIBinder;
83     @Mock private IFederatedComputeService mMockIService;
84 
85     @Parameterized.Parameters
data()86     public static Collection<Object[]> data() {
87         return Arrays.asList(
88                 new Object[][] {
89                         {"schedule-allNull", null, null, null},
90                         {
91                                 "schedule-default-iService",
92                                 new ScheduleFederatedComputeRequest.Builder()
93                                         .setTrainingOptions(new TrainingOptions.Builder().build())
94                                         .build(),
95                                 null,
96                                 new IFederatedComputeService.Default()
97                         },
98                         {
99                                 "schedule-mockIService-RemoteException",
100                                 new ScheduleFederatedComputeRequest.Builder()
101                                         .setTrainingOptions(new TrainingOptions.Builder().build())
102                                         .build(),
103                                 null,
104                                 null /* mock will be returned */
105                         },
106                         {
107                                 "schedule-mockIService-onSuccess",
108                                 new ScheduleFederatedComputeRequest.Builder()
109                                         .setTrainingOptions(new TrainingOptions.Builder().build())
110                                         .build(),
111                                 null,
112                                 null /* mock will be returned */
113                         },
114                         {
115                                 "schedule-mockIService-onFailure",
116                                 new ScheduleFederatedComputeRequest.Builder()
117                                         .setTrainingOptions(new TrainingOptions.Builder().build())
118                                         .build(),
119                                 null,
120                                 null /* mock will be returned */
121                         },
122                         {"cancel-allNull", null, null, null},
123                         {
124                                 "cancel-default-iService",
125                                 null,
126                                 "testPopulation",
127                                 new IFederatedComputeService.Default()
128                         },
129                         {
130                                 "cancel-mockIService-RemoteException",
131                                 null,
132                                 "testPopulation",
133                                 null /* mock will be returned */
134                         },
135                         {
136                                 "cancel-mockIService-onSuccess",
137                                 null,
138                                 "testPopulation",
139                                 null /* mock will be returned */
140                         },
141                         {
142                                 "cancel-mockIService-onFailure",
143                                 null,
144                                 "testPopulation",
145                                 null /* mock will be returned */
146                         },
147                 });
148     }
149 
150     @Before
setUp()151     public void setUp() {
152         MockitoAnnotations.initMocks(this);
153         ResolveInfo resolveInfo = new ResolveInfo();
154         ServiceInfo serviceInfo = new ServiceInfo();
155         serviceInfo.name = "TestName";
156         serviceInfo.packageName = "com.android.federatedcompute.services";
157         resolveInfo.serviceInfo = serviceInfo;
158         when(mMockPackageManager.queryIntentServices(any(), anyInt()))
159                 .thenReturn(List.of(resolveInfo));
160         when(mMockIBinder.queryLocalInterface(any())).thenReturn(iFederatedComputeService);
161     }
162 
163     @Test
testScheduleFederatedCompute()164     public void testScheduleFederatedCompute() throws RemoteException {
165         FederatedComputeManager manager = new FederatedComputeManager(mContext);
166         OutcomeReceiver<Object, Exception> spyCallback;
167 
168         switch (scenario) {
169             case "schedule-allNull":
170                 assertThrows(
171                         NullPointerException.class, () -> manager.schedule(request, null, null));
172                 break;
173             case "schedule-default-iService":
174                 manager.schedule(request, Executors.newSingleThreadExecutor(), null);
175                 break;
176             case "schedule-mockIService-RemoteException":
177                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
178                 doThrow(new RemoteException()).when(mMockIService).schedule(any(), any(), any());
179                 spyCallback = spy(new MyTestCallback());
180 
181                 manager.schedule(request, Runnable::run, spyCallback);
182 
183                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
184                 verify(spyCallback, times(1)).onError(any(RemoteException.class));
185                 verify(mContext, times(1)).unbindService(any());
186                 break;
187             case "schedule-mockIService-onSuccess":
188                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
189                 doAnswer(
190                         invocation -> {
191                             IFederatedComputeCallback federatedComputeCallback =
192                                     invocation.getArgument(2);
193                             federatedComputeCallback.onSuccess();
194                             return null;
195                         })
196                         .when(mMockIService)
197                         .schedule(any(), any(), any());
198                 spyCallback = spy(new MyTestCallback());
199 
200                 manager.schedule(request, Runnable::run, spyCallback);
201 
202                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
203                 verify(spyCallback, times(1)).onResult(isNull());
204                 verify(mContext, times(1)).unbindService(any());
205                 break;
206             case "schedule-mockIService-onFailure":
207                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
208                 doAnswer(
209                         invocation -> {
210                             IFederatedComputeCallback federatedComputeCallback =
211                                     invocation.getArgument(2);
212                             federatedComputeCallback.onFailure(1);
213                             return null;
214                         })
215                         .when(mMockIService)
216                         .schedule(any(), any(), any());
217                 spyCallback = spy(new MyTestCallback());
218 
219                 manager.schedule(request, Runnable::run, spyCallback);
220 
221                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
222                 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class));
223                 verify(mContext, times(1)).unbindService(any());
224                 break;
225             case "cancel-allNull":
226                 assertThrows(
227                         NullPointerException.class,
228                         () ->
229                                 manager.cancel(
230                                         OWNER_COMPONENT,
231                                         populationName,
232                                         null,
233                                         null));
234                 break;
235             case "cancel-default-iService":
236                 manager.cancel(
237                         OWNER_COMPONENT,
238                         populationName,
239                         Executors.newSingleThreadExecutor(),
240                         null);
241                 break;
242             case "cancel-mockIService-RemoteException":
243                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
244                 doThrow(new RemoteException())
245                         .when(mMockIService)
246                         .cancel(any(), any(), any());
247                 spyCallback = spy(new MyTestCallback());
248 
249                 manager.cancel(
250                         OWNER_COMPONENT,
251                         populationName,
252                         Runnable::run,
253                         spyCallback);
254 
255                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
256                 verify(spyCallback, times(1)).onError(any(RemoteException.class));
257                 verify(mContext, times(1)).unbindService(any());
258                 break;
259             case "cancel-mockIService-onSuccess":
260                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
261                 doAnswer(
262                         invocation -> {
263                             IFederatedComputeCallback federatedComputeCallback =
264                                     invocation.getArgument(2);
265                             federatedComputeCallback.onSuccess();
266                             return null;
267                         })
268                         .when(mMockIService)
269                         .cancel(any(), any(), any());
270                 spyCallback = spy(new MyTestCallback());
271 
272                 manager.cancel(
273                         OWNER_COMPONENT,
274                         populationName,
275                         Runnable::run,
276                         spyCallback);
277 
278                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
279                 verify(spyCallback, times(1)).onResult(isNull());
280                 verify(mContext, times(1)).unbindService(any());
281                 break;
282             case "cancel-mockIService-onFailure":
283                 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService);
284                 doAnswer(
285                         invocation -> {
286                             IFederatedComputeCallback federatedComputeCallback =
287                                     invocation.getArgument(2);
288                             federatedComputeCallback.onFailure(1);
289                             return null;
290                         })
291                         .when(mMockIService)
292                         .cancel(any(), any(), any());
293                 spyCallback = spy(new MyTestCallback());
294 
295                 manager.cancel(
296                         OWNER_COMPONENT,
297                         populationName,
298                         Runnable::run,
299                         spyCallback);
300 
301                 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any());
302                 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class));
303                 verify(mContext, times(1)).unbindService(any());
304                 break;
305             default:
306                 break;
307         }
308     }
309 
310     public class MyTestContext extends ContextWrapper {
311 
MyTestContext(Context context)312         MyTestContext(Context context) {
313             super(context);
314         }
315 
316         @Override
getPackageManager()317         public PackageManager getPackageManager() {
318             return mMockPackageManager != null ? mMockPackageManager : super.getPackageManager();
319         }
320 
321         @Override
bindService( Intent service, int flags, Executor executor, ServiceConnection conn)322         public boolean bindService(
323                 Intent service, int flags, Executor executor, ServiceConnection conn) {
324             executor.execute(
325                     () -> {
326                         conn.onServiceConnected(null, mMockIBinder);
327                     });
328             return true;
329         }
330 
unbindService(ServiceConnection conn)331         public void unbindService(ServiceConnection conn) {}
332     }
333 
334     public class MyTestCallback implements OutcomeReceiver<Object, Exception> {
335 
336         @Override
onResult(Object o)337         public void onResult(Object o) {}
338 
339         @Override
onError(Exception error)340         public void onError(Exception error) {
341             OutcomeReceiver.super.onError(error);
342         }
343     }
344 }
345