1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package android.federatedcompute; 17 18 import static org.junit.Assert.assertThrows; 19 import static org.mockito.ArgumentMatchers.any; 20 import static org.mockito.ArgumentMatchers.anyInt; 21 import static org.mockito.ArgumentMatchers.isNull; 22 import static org.mockito.Mockito.doAnswer; 23 import static org.mockito.Mockito.doThrow; 24 import static org.mockito.Mockito.spy; 25 import static org.mockito.Mockito.times; 26 import static org.mockito.Mockito.verify; 27 import static org.mockito.Mockito.when; 28 29 import android.content.ComponentName; 30 import android.content.Context; 31 import android.content.ContextWrapper; 32 import android.content.Intent; 33 import android.content.ServiceConnection; 34 import android.content.pm.PackageManager; 35 import android.content.pm.ResolveInfo; 36 import android.content.pm.ServiceInfo; 37 import android.federatedcompute.aidl.IFederatedComputeCallback; 38 import android.federatedcompute.aidl.IFederatedComputeService; 39 import android.federatedcompute.common.ScheduleFederatedComputeRequest; 40 import android.federatedcompute.common.TrainingOptions; 41 import android.os.IBinder; 42 import android.os.OutcomeReceiver; 43 import android.os.RemoteException; 44 45 import androidx.test.core.app.ApplicationProvider; 46 47 import org.junit.Before; 48 import org.junit.Test; 49 import org.junit.runner.RunWith; 50 import org.junit.runners.Parameterized; 51 import org.mockito.Mock; 52 import org.mockito.MockitoAnnotations; 53 54 import java.util.Arrays; 55 import java.util.Collection; 56 import java.util.List; 57 import java.util.concurrent.Executor; 58 import java.util.concurrent.Executors; 59 60 @RunWith(Parameterized.class) 61 public class FederatedComputeManagerTest { 62 63 private final Context mContext = 64 spy(new MyTestContext(ApplicationProvider.getApplicationContext())); 65 66 private static final ComponentName OWNER_COMPONENT = 67 ComponentName.createRelative("com.android.package.name", "com.android.class.name"); 68 69 @Parameterized.Parameter(0) 70 public String scenario; 71 72 @Parameterized.Parameter(1) 73 public ScheduleFederatedComputeRequest request; 74 75 @Parameterized.Parameter(2) 76 public String populationName; 77 78 @Parameterized.Parameter(3) 79 public IFederatedComputeService iFederatedComputeService; 80 81 @Mock private PackageManager mMockPackageManager; 82 @Mock private IBinder mMockIBinder; 83 @Mock private IFederatedComputeService mMockIService; 84 85 @Parameterized.Parameters data()86 public static Collection<Object[]> data() { 87 return Arrays.asList( 88 new Object[][] { 89 {"schedule-allNull", null, null, null}, 90 { 91 "schedule-default-iService", 92 new ScheduleFederatedComputeRequest.Builder() 93 .setTrainingOptions(new TrainingOptions.Builder().build()) 94 .build(), 95 null, 96 new IFederatedComputeService.Default() 97 }, 98 { 99 "schedule-mockIService-RemoteException", 100 new ScheduleFederatedComputeRequest.Builder() 101 .setTrainingOptions(new TrainingOptions.Builder().build()) 102 .build(), 103 null, 104 null /* mock will be returned */ 105 }, 106 { 107 "schedule-mockIService-onSuccess", 108 new ScheduleFederatedComputeRequest.Builder() 109 .setTrainingOptions(new TrainingOptions.Builder().build()) 110 .build(), 111 null, 112 null /* mock will be returned */ 113 }, 114 { 115 "schedule-mockIService-onFailure", 116 new ScheduleFederatedComputeRequest.Builder() 117 .setTrainingOptions(new TrainingOptions.Builder().build()) 118 .build(), 119 null, 120 null /* mock will be returned */ 121 }, 122 {"cancel-allNull", null, null, null}, 123 { 124 "cancel-default-iService", 125 null, 126 "testPopulation", 127 new IFederatedComputeService.Default() 128 }, 129 { 130 "cancel-mockIService-RemoteException", 131 null, 132 "testPopulation", 133 null /* mock will be returned */ 134 }, 135 { 136 "cancel-mockIService-onSuccess", 137 null, 138 "testPopulation", 139 null /* mock will be returned */ 140 }, 141 { 142 "cancel-mockIService-onFailure", 143 null, 144 "testPopulation", 145 null /* mock will be returned */ 146 }, 147 }); 148 } 149 150 @Before setUp()151 public void setUp() { 152 MockitoAnnotations.initMocks(this); 153 ResolveInfo resolveInfo = new ResolveInfo(); 154 ServiceInfo serviceInfo = new ServiceInfo(); 155 serviceInfo.name = "TestName"; 156 serviceInfo.packageName = "com.android.federatedcompute.services"; 157 resolveInfo.serviceInfo = serviceInfo; 158 when(mMockPackageManager.queryIntentServices(any(), anyInt())) 159 .thenReturn(List.of(resolveInfo)); 160 when(mMockIBinder.queryLocalInterface(any())).thenReturn(iFederatedComputeService); 161 } 162 163 @Test testScheduleFederatedCompute()164 public void testScheduleFederatedCompute() throws RemoteException { 165 FederatedComputeManager manager = new FederatedComputeManager(mContext); 166 OutcomeReceiver<Object, Exception> spyCallback; 167 168 switch (scenario) { 169 case "schedule-allNull": 170 assertThrows( 171 NullPointerException.class, () -> manager.schedule(request, null, null)); 172 break; 173 case "schedule-default-iService": 174 manager.schedule(request, Executors.newSingleThreadExecutor(), null); 175 break; 176 case "schedule-mockIService-RemoteException": 177 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 178 doThrow(new RemoteException()).when(mMockIService).schedule(any(), any(), any()); 179 spyCallback = spy(new MyTestCallback()); 180 181 manager.schedule(request, Runnable::run, spyCallback); 182 183 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 184 verify(spyCallback, times(1)).onError(any(RemoteException.class)); 185 verify(mContext, times(1)).unbindService(any()); 186 break; 187 case "schedule-mockIService-onSuccess": 188 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 189 doAnswer( 190 invocation -> { 191 IFederatedComputeCallback federatedComputeCallback = 192 invocation.getArgument(2); 193 federatedComputeCallback.onSuccess(); 194 return null; 195 }) 196 .when(mMockIService) 197 .schedule(any(), any(), any()); 198 spyCallback = spy(new MyTestCallback()); 199 200 manager.schedule(request, Runnable::run, spyCallback); 201 202 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 203 verify(spyCallback, times(1)).onResult(isNull()); 204 verify(mContext, times(1)).unbindService(any()); 205 break; 206 case "schedule-mockIService-onFailure": 207 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 208 doAnswer( 209 invocation -> { 210 IFederatedComputeCallback federatedComputeCallback = 211 invocation.getArgument(2); 212 federatedComputeCallback.onFailure(1); 213 return null; 214 }) 215 .when(mMockIService) 216 .schedule(any(), any(), any()); 217 spyCallback = spy(new MyTestCallback()); 218 219 manager.schedule(request, Runnable::run, spyCallback); 220 221 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 222 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); 223 verify(mContext, times(1)).unbindService(any()); 224 break; 225 case "cancel-allNull": 226 assertThrows( 227 NullPointerException.class, 228 () -> 229 manager.cancel( 230 OWNER_COMPONENT, 231 populationName, 232 null, 233 null)); 234 break; 235 case "cancel-default-iService": 236 manager.cancel( 237 OWNER_COMPONENT, 238 populationName, 239 Executors.newSingleThreadExecutor(), 240 null); 241 break; 242 case "cancel-mockIService-RemoteException": 243 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 244 doThrow(new RemoteException()) 245 .when(mMockIService) 246 .cancel(any(), any(), any()); 247 spyCallback = spy(new MyTestCallback()); 248 249 manager.cancel( 250 OWNER_COMPONENT, 251 populationName, 252 Runnable::run, 253 spyCallback); 254 255 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 256 verify(spyCallback, times(1)).onError(any(RemoteException.class)); 257 verify(mContext, times(1)).unbindService(any()); 258 break; 259 case "cancel-mockIService-onSuccess": 260 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 261 doAnswer( 262 invocation -> { 263 IFederatedComputeCallback federatedComputeCallback = 264 invocation.getArgument(2); 265 federatedComputeCallback.onSuccess(); 266 return null; 267 }) 268 .when(mMockIService) 269 .cancel(any(), any(), any()); 270 spyCallback = spy(new MyTestCallback()); 271 272 manager.cancel( 273 OWNER_COMPONENT, 274 populationName, 275 Runnable::run, 276 spyCallback); 277 278 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 279 verify(spyCallback, times(1)).onResult(isNull()); 280 verify(mContext, times(1)).unbindService(any()); 281 break; 282 case "cancel-mockIService-onFailure": 283 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 284 doAnswer( 285 invocation -> { 286 IFederatedComputeCallback federatedComputeCallback = 287 invocation.getArgument(2); 288 federatedComputeCallback.onFailure(1); 289 return null; 290 }) 291 .when(mMockIService) 292 .cancel(any(), any(), any()); 293 spyCallback = spy(new MyTestCallback()); 294 295 manager.cancel( 296 OWNER_COMPONENT, 297 populationName, 298 Runnable::run, 299 spyCallback); 300 301 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 302 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); 303 verify(mContext, times(1)).unbindService(any()); 304 break; 305 default: 306 break; 307 } 308 } 309 310 public class MyTestContext extends ContextWrapper { 311 MyTestContext(Context context)312 MyTestContext(Context context) { 313 super(context); 314 } 315 316 @Override getPackageManager()317 public PackageManager getPackageManager() { 318 return mMockPackageManager != null ? mMockPackageManager : super.getPackageManager(); 319 } 320 321 @Override bindService( Intent service, int flags, Executor executor, ServiceConnection conn)322 public boolean bindService( 323 Intent service, int flags, Executor executor, ServiceConnection conn) { 324 executor.execute( 325 () -> { 326 conn.onServiceConnected(null, mMockIBinder); 327 }); 328 return true; 329 } 330 unbindService(ServiceConnection conn)331 public void unbindService(ServiceConnection conn) {} 332 } 333 334 public class MyTestCallback implements OutcomeReceiver<Object, Exception> { 335 336 @Override onResult(Object o)337 public void onResult(Object o) {} 338 339 @Override onError(Exception error)340 public void onError(Exception error) { 341 OutcomeReceiver.super.onError(error); 342 } 343 } 344 } 345