1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.ondevicepersonalization; 18 19 import static junit.framework.Assert.assertEquals; 20 21 import static org.junit.Assert.assertNotNull; 22 import static org.junit.Assert.assertThrows; 23 import static org.junit.Assert.assertTrue; 24 25 import android.adservices.ondevicepersonalization.aidl.IDataAccessService; 26 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; 27 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService; 28 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback; 29 import android.os.Bundle; 30 import android.os.OutcomeReceiver; 31 import android.os.RemoteException; 32 33 import androidx.test.ext.junit.runners.AndroidJUnit4; 34 import androidx.test.filters.SmallTest; 35 36 import com.google.common.util.concurrent.MoreExecutors; 37 38 import org.junit.Before; 39 import org.junit.Test; 40 import org.junit.runner.RunWith; 41 42 import java.util.HashMap; 43 import java.util.concurrent.CountDownLatch; 44 45 @SmallTest 46 @RunWith(AndroidJUnit4.class) 47 public class ModelManagerTest { 48 ModelManager mModelManager = 49 new ModelManager( 50 IDataAccessService.Stub.asInterface(new TestDataAccessService()), 51 IIsolatedModelService.Stub.asInterface(new TestIsolatedModelService())); 52 53 private static final String INVALID_MODEL_KEY = "invalid_key"; 54 private static final String MODEL_KEY = "model_key"; 55 private static final String MISSING_OUTPUT_KEY = "missing-output-key"; 56 private boolean mRunInferenceCalled = false; 57 private RemoteDataImpl mRemoteData; 58 59 @Before setup()60 public void setup() { 61 mRemoteData = 62 new RemoteDataImpl( 63 IDataAccessService.Stub.asInterface(new TestDataAccessService())); 64 } 65 66 @Test runInference_success()67 public void runInference_success() throws Exception { 68 HashMap<Integer, Object> outputData = new HashMap<>(); 69 outputData.put(0, new float[1]); 70 Object[] input = new Object[1]; 71 input[0] = new float[] {1.2f}; 72 InferenceInput inferenceContext = 73 new InferenceInput.Builder( 74 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(), 75 input, 76 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 77 .build(); 78 79 var callback = new MyTestCallback(); 80 mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback); 81 82 callback.mLatch.await(); 83 assertTrue(mRunInferenceCalled); 84 assertNotNull(callback.mInferenceOutput); 85 float[] value = (float[]) callback.mInferenceOutput.getDataOutputs().get(0); 86 assertEquals(value[0], 5.0f, 0.01f); 87 } 88 89 @Test runInference_error()90 public void runInference_error() throws Exception { 91 HashMap<Integer, Object> outputData = new HashMap<>(); 92 outputData.put(0, new float[1]); 93 Object[] input = new Object[1]; 94 input[0] = new float[] {1.2f}; 95 InferenceInput inferenceContext = 96 new InferenceInput.Builder( 97 new InferenceInput.Params.Builder(mRemoteData, INVALID_MODEL_KEY) 98 .build(), 99 input, 100 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 101 .build(); 102 103 var callback = new MyTestCallback(); 104 mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback); 105 106 callback.mLatch.await(); 107 assertTrue(callback.mError); 108 } 109 110 @Test runInference_contextNull_throw()111 public void runInference_contextNull_throw() { 112 assertThrows( 113 NullPointerException.class, 114 () -> 115 mModelManager.run( 116 null, MoreExecutors.directExecutor(), new MyTestCallback())); 117 } 118 119 @Test runInference_resultMissingInferenceOutput()120 public void runInference_resultMissingInferenceOutput() throws Exception { 121 HashMap<Integer, Object> outputData = new HashMap<>(); 122 outputData.put(0, new float[1]); 123 Object[] inputData = new Object[1]; 124 inputData[0] = new float[] {1.2f}; 125 InferenceInput inferenceContext = 126 new InferenceInput.Builder( 127 new InferenceInput.Params.Builder(mRemoteData, MISSING_OUTPUT_KEY) 128 .build(), 129 inputData, 130 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 131 .build(); 132 133 var callback = new MyTestCallback(); 134 mModelManager.run(inferenceContext, MoreExecutors.directExecutor(), callback); 135 136 callback.mLatch.await(); 137 assertTrue(callback.mError); 138 } 139 140 public class MyTestCallback implements OutcomeReceiver<InferenceOutput, Exception> { 141 public boolean mError = false; 142 public InferenceOutput mInferenceOutput = null; 143 private final CountDownLatch mLatch = new CountDownLatch(1); 144 145 @Override onResult(InferenceOutput result)146 public void onResult(InferenceOutput result) { 147 mInferenceOutput = result; 148 mLatch.countDown(); 149 } 150 151 @Override onError(Exception error)152 public void onError(Exception error) { 153 mError = true; 154 mLatch.countDown(); 155 } 156 } 157 158 class TestIsolatedModelService extends IIsolatedModelService.Stub { 159 @Override runInference(Bundle params, IIsolatedModelServiceCallback callback)160 public void runInference(Bundle params, IIsolatedModelServiceCallback callback) 161 throws RemoteException { 162 mRunInferenceCalled = true; 163 InferenceInputParcel inputParcel = 164 params.getParcelable( 165 Constants.EXTRA_INFERENCE_INPUT, InferenceInputParcel.class); 166 if (inputParcel.getModelId().getKey().equals(INVALID_MODEL_KEY)) { 167 callback.onError(Constants.STATUS_INTERNAL_ERROR); 168 return; 169 } 170 if (inputParcel.getModelId().getKey().equals(MISSING_OUTPUT_KEY)) { 171 callback.onSuccess(new Bundle()); 172 return; 173 } 174 HashMap<Integer, Object> result = new HashMap<>(); 175 result.put(0, new float[] {5.0f}); 176 Bundle bundle = new Bundle(); 177 bundle.putParcelable( 178 Constants.EXTRA_RESULT, 179 new InferenceOutputParcel( 180 new InferenceOutput.Builder().setDataOutputs(result).build())); 181 callback.onSuccess(bundle); 182 } 183 } 184 185 static class TestDataAccessService extends IDataAccessService.Stub { 186 @Override onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)187 public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {} 188 @Override logApiCallStats(int apiName, long latencyMillis, int responseCode)189 public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {} 190 } 191 } 192