1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.federatedcompute;
18 
import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;
import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS;
import static android.federatedcompute.common.TrainingInterval.SCHEDULING_MODE_ONE_TIME;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assert.assertTrue;

import android.federatedcompute.aidl.IFederatedComputeCallback;
import android.federatedcompute.aidl.IResultHandlingService;
import android.federatedcompute.common.ClientConstants;
import android.federatedcompute.common.ExampleConsumption;
import android.federatedcompute.common.TrainingInterval;
import android.federatedcompute.common.TrainingOptions;
import android.os.Bundle;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
44 
45 @RunWith(AndroidJUnit4.class)
46 public final class ResultHandlingServiceTest {
47     private static final String TASK_NAME = "task-name";
48     private static final String TEST_POPULATION = "testPopulation";
49     private static final int JOB_ID = 12345;
50     private static final byte[] SELECTION_CRITERIA = new byte[] {10, 0, 1};
51     private static final TrainingOptions TRAINING_OPTIONS =
52             new TrainingOptions.Builder()
53                     .setPopulationName(TEST_POPULATION)
54                     .setTrainingInterval(
55                             new TrainingInterval.Builder()
56                                     .setSchedulingMode(SCHEDULING_MODE_ONE_TIME)
57                                     .build())
58                     .build();
59     private static final ArrayList<ExampleConsumption> EXAMPLE_CONSUMPTIONS =
60             createExampleConsumptionList();
61 
62     private boolean mSuccess = false;
63     private boolean mHandleResultCalled = false;
64     private int mErrorCode = 0;
65     private final CountDownLatch mLatch = new CountDownLatch(1);
66 
67     private IResultHandlingService mBinder;
68     private final TestResultHandlingService mTestResultHandlingService =
69             new TestResultHandlingService();
70 
71     @Before
doBeforeEachTest()72     public void doBeforeEachTest() {
73         mTestResultHandlingService.onCreate();
74         mBinder = IResultHandlingService.Stub.asInterface(mTestResultHandlingService.onBind(null));
75     }
76 
77     @Test
testHandleResult_success()78     public void testHandleResult_success() throws Exception {
79         Bundle input = new Bundle();
80         input.putString(ClientConstants.EXTRA_TASK_ID, TASK_NAME);
81         input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION);
82         input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS);
83         input.putParcelableArrayList(
84                 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, EXAMPLE_CONSUMPTIONS);
85 
86         mBinder.handleResult(input, new TestFederatedComputeCallback());
87 
88         mLatch.await();
89         assertTrue(mHandleResultCalled);
90         assertTrue(mSuccess);
91     }
92 
93     @Test
testHandleResult_failure()94     public void testHandleResult_failure() throws Exception {
95         Bundle input = new Bundle();
96         input.putString(ClientConstants.EXTRA_TASK_ID, TASK_NAME);
97         input.putString(ClientConstants.EXTRA_POPULATION_NAME, TEST_POPULATION);
98         input.putInt(ClientConstants.EXTRA_COMPUTATION_RESULT, STATUS_SUCCESS);
99         input.putParcelableArrayList(ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, null);
100 
101         mBinder.handleResult(input, new TestFederatedComputeCallback());
102 
103         mLatch.await();
104         assertTrue(mHandleResultCalled);
105         assertThat(mErrorCode).isEqualTo(STATUS_INTERNAL_ERROR);
106     }
107 
108     class TestResultHandlingService extends ResultHandlingService {
109         @Override
handleResult(Bundle input, Consumer<Integer> callback)110         public void handleResult(Bundle input, Consumer<Integer> callback) {
111             mHandleResultCalled = true;
112             ArrayList<ExampleConsumption> exampleConsumptionList =
113                     input.getParcelableArrayList(
114                             ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST,
115                             ExampleConsumption.class);
116             if (exampleConsumptionList == null || exampleConsumptionList.isEmpty()) {
117                 callback.accept(STATUS_INTERNAL_ERROR);
118                 return;
119             }
120             callback.accept(STATUS_SUCCESS);
121         }
122     }
123 
createExampleConsumptionList()124     private static ArrayList<ExampleConsumption> createExampleConsumptionList() {
125         ArrayList<ExampleConsumption> exampleList = new ArrayList<>();
126         exampleList.add(
127                 new ExampleConsumption.Builder()
128                         .setTaskId("taskName")
129                         .setExampleCount(100)
130                         .setSelectionCriteria(SELECTION_CRITERIA)
131                         .build());
132         return exampleList;
133     }
134 
135     class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub {
136         @Override
onSuccess()137         public void onSuccess() {
138             mSuccess = true;
139             mLatch.countDown();
140         }
141 
142         @Override
onFailure(int errorCode)143         public void onFailure(int errorCode) {
144             mErrorCode = errorCode;
145             mLatch.countDown();
146         }
147     }
148 }
149