1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.federatedcompute; 18 19 import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; 20 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; 21 import static android.federatedcompute.common.TrainingInterval.SCHEDULING_MODE_ONE_TIME; 22 23 import static com.google.common.truth.Truth.assertThat; 24 25 import static org.junit.Assert.assertTrue; 26 27 import android.federatedcompute.aidl.IFederatedComputeCallback; 28 import android.federatedcompute.aidl.IResultHandlingService; 29 import android.federatedcompute.common.ClientConstants; 30 import android.federatedcompute.common.ExampleConsumption; 31 import android.federatedcompute.common.TrainingInterval; 32 import android.federatedcompute.common.TrainingOptions; 33 import android.os.Bundle; 34 35 import androidx.test.ext.junit.runners.AndroidJUnit4; 36 37 import org.junit.Before; 38 import org.junit.Test; 39 import org.junit.runner.RunWith; 40 41 import java.util.ArrayList; 42 import java.util.concurrent.CountDownLatch; 43 import java.util.function.Consumer; 44 45 @RunWith(AndroidJUnit4.class) 46 public final class ResultHandlingServiceTest { 47 private static final String TASK_NAME = "task-name"; 48 private static final String TEST_POPULATION = "testPopulation"; 49 private static final int JOB_ID = 12345; 50 private static final byte[] SELECTION_CRITERIA = new byte[] {10, 0, 1}; 51 private static final TrainingOptions TRAINING_OPTIONS = 52 new TrainingOptions.Builder() 53 .setPopulationName(TEST_POPULATION) 54 .setTrainingInterval( 55 new TrainingInterval.Builder() 56 .setSchedulingMode(SCHEDULING_MODE_ONE_TIME) 57 .build()) 58 .build(); 59 private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS = 60 createExampleConsumptionList(); 61 62 private boolean mSuccess = false; 63 private boolean mHandleResultCalled = false; 64 private int mErrorCode = 0; 65 private final CountDownLatch mLatch = new CountDownLatch(1); 66 67 private IResultHandlingService mBinder; 68 private final TestResultHandlingService mTestResultHandlingService = 69 new TestResultHandlingService(); 70 71 @Before doBeforeEachTest()72 public void doBeforeEachTest() { 73 mTestResultHandlingService.onCreate(); 74 mBinder = IResultHandlingService.Stub.asInterface(mTestResultHandlingService.onBind(null)); 75 } 76 77 @Test testHandleResult_success()78 public void testHandleResult_success() throws Exception { 79 Bundle input = new Bundle(); 80 input.putString(ClientConstants.EXTRA_TASK_ID, TASK_NAME); 81 input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION); 82 input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS); 83 input.putParcelableArrayList( 84 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, EXAMPLE_CONSUMPTIONS); 85 86 mBinder.handleResult(input, new TestFederatedComputeCallback()); 87 88 mLatch.await(); 89 assertTrue(mHandleResultCalled); 90 assertTrue(mSuccess); 91 } 92 93 @Test testHandleResult_failure()94 public void testHandleResult_failure() throws Exception { 95 Bundle input = new Bundle(); 96 input.putString(ClientConstants.EXTRA_TASK_ID, TASK_NAME); 97 input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION); 98 input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS); 99 input.putParcelableArrayList(ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, null); 100 101 mBinder.handleResult(input, new TestFederatedComputeCallback()); 102 103 mLatch.await(); 104 assertTrue(mHandleResultCalled); 105 assertThat(mErrorCode).isEqualTo(STATUS_INTERNAL_ERROR); 106 } 107 108 class TestResultHandlingService extends ResultHandlingService { 109 @Override handleResult(Bundle input, Consumer<Integer> callback)110 public void handleResult(Bundle input, Consumer<Integer> callback) { 111 mHandleResultCalled = true; 112 ArrayList<ExampleConsumption> exampleConsumptionList = 113 input.getParcelableArrayList( 114 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, 115 ExampleConsumption.class); 116 if (exampleConsumptionList == null || exampleConsumptionList.isEmpty()) { 117 callback.accept(STATUS_INTERNAL_ERROR); 118 return; 119 } 120 callback.accept(STATUS_SUCCESS); 121 } 122 } 123 createExampleConsumptionList()124 private static ArrayList<ExampleConsumption> createExampleConsumptionList() { 125 ArrayList<ExampleConsumption> exampleList = new ArrayList<>(); 126 exampleList.add( 127 new ExampleConsumption.Builder() 128 .setTaskId("taskName") 129 .setExampleCount(100) 130 .setSelectionCriteria(SELECTION_CRITERIA) 131 .build()); 132 return exampleList; 133 } 134 135 class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub { 136 @Override onSuccess()137 public void onSuccess() { 138 mSuccess = true; 139 mLatch.countDown(); 140 } 141 142 @Override onFailure(int errorCode)143 public void onFailure(int errorCode) { 144 mErrorCode = errorCode; 145 mLatch.countDown(); 146 } 147 } 148 } 149