1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
import static junit.framework.Assert.assertEquals;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
import android.os.Bundle;
import android.os.OutcomeReceiver;
import android.os.RemoteException;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;

import com.google.common.util.concurrent.MoreExecutors;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
44 
45 @SmallTest
46 @RunWith(AndroidJUnit4.class)
47 public class ModelManagerTest {
48     ModelManager mModelManager =
49             new ModelManager(
50                     IDataAccessService.Stub.asInterface(new TestDataAccessService()),
51                     IIsolatedModelService.Stub.asInterface(new TestIsolatedModelService()));
52 
53     private static final String INVALID_MODEL_KEY = "invalid_key";
54     private static final String MODEL_KEY = "model_key";
55     private static final String MISSING_OUTPUT_KEY = "missing-output-key";
56     private boolean mRunInferenceCalled = false;
57     private RemoteDataImpl mRemoteData;
58 
59     @Before
setup()60     public void setup() {
61         mRemoteData =
62                 new RemoteDataImpl(
63                         IDataAccessService.Stub.asInterface(new TestDataAccessService()));
64     }
65 
66     @Test
runInference_success()67     public void runInference_success() throws Exception {
68         HashMap<Integer, Object> outputData = new HashMap<>();
69         outputData.put(0, new float[1]);
70         Object[] input = new Object[1];
71         input[0] = new float[] {1.2f};
72         InferenceInput inferenceContext =
73                 new InferenceInput.Builder(
74                                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(),
75                                 input,
76                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
77                         .build();
78 
79         var callback = new MyTestCallback();
80         mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback);
81 
82         callback.mLatch.await();
83         assertTrue(mRunInferenceCalled);
84         assertNotNull(callback.mInferenceOutput);
85         float[] value = (float[]) callback.mInferenceOutput.getDataOutputs().get(0);
86         assertEquals(value[0], 5.0f, 0.01f);
87     }
88 
89     @Test
runInference_error()90     public void runInference_error() throws Exception {
91         HashMap<Integer, Object> outputData = new HashMap<>();
92         outputData.put(0, new float[1]);
93         Object[] input = new Object[1];
94         input[0] = new float[] {1.2f};
95         InferenceInput inferenceContext =
96                 new InferenceInput.Builder(
97                                 new InferenceInput.Params.Builder(mRemoteData, INVALID_MODEL_KEY)
98                                         .build(),
99                                 input,
100                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
101                         .build();
102 
103         var callback = new MyTestCallback();
104         mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback);
105 
106         callback.mLatch.await();
107         assertTrue(callback.mError);
108     }
109 
110     @Test
runInference_contextNull_throw()111     public void runInference_contextNull_throw() {
112         assertThrows(
113                 NullPointerException.class,
114                 () ->
115                         mModelManager.run(
116                                 null, MoreExecutors.directExecutor(), new MyTestCallback()));
117     }
118 
119     @Test
runInference_resultMissingInferenceOutput()120     public void runInference_resultMissingInferenceOutput() throws Exception {
121         HashMap<Integer, Object> outputData = new HashMap<>();
122         outputData.put(0, new float[1]);
123         Object[] inputData = new Object[1];
124         inputData[0] = new float[] {1.2f};
125         InferenceInput inferenceContext =
126                 new InferenceInput.Builder(
127                                 new InferenceInput.Params.Builder(mRemoteData, MISSING_OUTPUT_KEY)
128                                         .build(),
129                                 inputData,
130                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
131                         .build();
132 
133         var callback = new MyTestCallback();
134         mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback);
135 
136         callback.mLatch.await();
137         assertTrue(callback.mError);
138     }
139 
140     public class MyTestCallback implements OutcomeReceiver<InferenceOutput, Exception> {
141         public boolean mError = false;
142         public InferenceOutput mInferenceOutput = null;
143         private final CountDownLatch mLatch = new CountDownLatch(1);
144 
145         @Override
onResult(InferenceOutput result)146         public void onResult(InferenceOutput result) {
147             mInferenceOutput = result;
148             mLatch.countDown();
149         }
150 
151         @Override
onError(Exception error)152         public void onError(Exception error) {
153             mError = true;
154             mLatch.countDown();
155         }
156     }
157 
158     class TestIsolatedModelService extends IIsolatedModelService.Stub {
159         @Override
runInference(Bundle params, IIsolatedModelServiceCallback callback)160         public void runInference(Bundle params, IIsolatedModelServiceCallback callback)
161                 throws RemoteException {
162             mRunInferenceCalled = true;
163             InferenceInputParcel inputParcel =
164                     params.getParcelable(
165                             Constants.EXTRA_INFERENCE_INPUT, InferenceInputParcel.class);
166             if (inputParcel.getModelId().getKey().equals(INVALID_MODEL_KEY)) {
167                 callback.onError(Constants.STATUS_INTERNAL_ERROR);
168                 return;
169             }
170             if (inputParcel.getModelId().getKey().equals(MISSING_OUTPUT_KEY)) {
171                 callback.onSuccess(new Bundle());
172                 return;
173             }
174             HashMap<Integer, Object> result = new HashMap<>();
175             result.put(0, new float[] {5.0f});
176             Bundle bundle = new Bundle();
177             bundle.putParcelable(
178                     Constants.EXTRA_RESULT,
179                     new InferenceOutputParcel(
180                             new InferenceOutput.Builder().setDataOutputs(result).build()));
181             callback.onSuccess(bundle);
182         }
183     }
184 
185     static class TestDataAccessService extends IDataAccessService.Stub {
186         @Override
onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)187         public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {}
188         @Override
logApiCallStats(int apiName, long latencyMillis, int responseCode)189         public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {}
190     }
191 }
192