1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.ondevicepersonalization.services.inference;
18 
19 import android.adservices.ondevicepersonalization.Constants;
20 import android.adservices.ondevicepersonalization.InferenceInputParcel;
21 import android.adservices.ondevicepersonalization.InferenceOutput;
22 import android.adservices.ondevicepersonalization.InferenceOutputParcel;
23 import android.adservices.ondevicepersonalization.ModelId;
24 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
25 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
26 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
27 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
28 import android.annotation.NonNull;
29 import android.os.Bundle;
30 import android.os.ParcelFileDescriptor;
31 import android.os.RemoteException;
32 import android.os.Trace;
33 
34 import com.android.internal.annotations.VisibleForTesting;
35 import com.android.ondevicepersonalization.internal.util.LoggerFactory;
36 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors;
37 import com.android.ondevicepersonalization.services.util.IoUtils;
38 
39 import com.google.common.util.concurrent.ListeningExecutorService;
40 
41 import org.tensorflow.lite.InterpreterApi;
42 import org.tensorflow.lite.Tensor;
43 
44 import java.io.ByteArrayInputStream;
45 import java.io.IOException;
46 import java.io.ObjectInputStream;
47 import java.nio.ByteBuffer;
48 import java.util.List;
49 import java.util.Map;
50 import java.util.Objects;
51 import java.util.concurrent.ArrayBlockingQueue;
52 import java.util.concurrent.BlockingQueue;
53 
54 /** The implementation of {@link IsolatedModelService}. */
55 public class IsolatedModelServiceImpl extends IIsolatedModelService.Stub {
56     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
57     private static final String TAG = IsolatedModelServiceImpl.class.getSimpleName();
58     @NonNull private final Injector mInjector;
59 
60     static {
61         System.loadLibrary("fcp_cpp_dep_jni");
62     }
63 
64     @VisibleForTesting
IsolatedModelServiceImpl(@onNull Injector injector)65     public IsolatedModelServiceImpl(@NonNull Injector injector) {
66         this.mInjector = injector;
67     }
68 
IsolatedModelServiceImpl()69     public IsolatedModelServiceImpl() {
70         this(new Injector());
71     }
72 
73     @Override
runInference(Bundle params, IIsolatedModelServiceCallback callback)74     public void runInference(Bundle params, IIsolatedModelServiceCallback callback) {
75         InferenceInputParcel inputParcel =
76                 Objects.requireNonNull(
77                         params.getParcelable(
78                                 Constants.EXTRA_INFERENCE_INPUT, InferenceInputParcel.class));
79         IDataAccessService binder =
80                 IDataAccessService.Stub.asInterface(
81                         Objects.requireNonNull(
82                                 params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER)));
83         InferenceOutputParcel outputParcel =
84                 Objects.requireNonNull(inputParcel.getExpectedOutputStructure());
85         mInjector
86                 .getExecutor()
87                 .execute(() -> runTfliteInterpreter(inputParcel, outputParcel, binder, callback));
88     }
89 
runTfliteInterpreter( InferenceInputParcel inputParcel, InferenceOutputParcel outputParcel, IDataAccessService binder, IIsolatedModelServiceCallback callback)90     private void runTfliteInterpreter(
91             InferenceInputParcel inputParcel,
92             InferenceOutputParcel outputParcel,
93             IDataAccessService binder,
94             IIsolatedModelServiceCallback callback) {
95         try {
96             Trace.beginSection("IsolatedModelService#RunInference");
97             Object[] inputs = convertToObjArray(inputParcel.getInputData().getList());
98             if (inputs.length == 0) {
99                 sendError(callback);
100             }
101 
102             ModelId modelId = inputParcel.getModelId();
103             ParcelFileDescriptor modelFd = fetchModel(binder, modelId);
104             ByteBuffer byteBuffer = IoUtils.getByteBufferFromFd(modelFd);
105             if (byteBuffer == null) {
106                 closeFd(modelFd);
107                 sendError(callback);
108             }
109             InterpreterApi interpreter =
110                     InterpreterApi.create(
111                             byteBuffer,
112                             new InterpreterApi.Options()
113                                     .setNumThreads(inputParcel.getCpuNumThread()));
114             Map<Integer, Object> outputs = outputParcel.getData();
115             if (outputs.isEmpty() || inputs.length == 0) {
116                 closeFd(modelFd);
117                 sendError(callback);
118             }
119 
120             // TODO(b/323469981): handle batch size better. Currently TFLite will throws error if
121             // batchSize doesn't match input data size.
122             int batchSize = inputParcel.getBatchSize();
123             for (int i = 0; i < interpreter.getInputTensorCount(); i++) {
124                 Tensor tensor = interpreter.getInputTensor(i);
125                 int[] shape = tensor.shape();
126                 shape[0] = batchSize;
127                 interpreter.resizeInput(i, shape);
128             }
129             interpreter.runForMultipleInputsOutputs(inputs, outputs);
130             interpreter.close();
131 
132             closeFd(modelFd);
133             Bundle bundle = new Bundle();
134             InferenceOutput result = new InferenceOutput.Builder().setDataOutputs(outputs).build();
135             bundle.putParcelable(Constants.EXTRA_RESULT, new InferenceOutputParcel(result));
136             sendResult(bundle, callback);
137             Trace.endSection();
138         } catch (Exception e) {
139             // Catch all exceptions including TFLite errors.
140             sLogger.e(e, TAG + ": Failed to run inference job.");
141             sendError(callback);
142         }
143     }
144 
convertToObjArray(List<byte[]> input)145     private Object[] convertToObjArray(List<byte[]> input) {
146         Object[] output = new Object[input.size()];
147         for (int i = 0; i < input.size(); i++) {
148             ByteArrayInputStream bais = new ByteArrayInputStream(input.get(i));
149             try {
150                 ObjectInputStream ois = new ObjectInputStream(bais);
151                 output[i] = ois.readObject();
152             } catch (IOException | ClassNotFoundException e) {
153                 throw new RuntimeException(e);
154             }
155         }
156         return output;
157     }
158 
closeFd(ParcelFileDescriptor fd)159     private void closeFd(ParcelFileDescriptor fd) {
160         try {
161             fd.close();
162         } catch (IOException e) {
163             sLogger.e(e, TAG + ": Failed to close model file descriptor");
164         }
165     }
166 
fetchModel(IDataAccessService dataAccessService, ModelId modelId)167     private ParcelFileDescriptor fetchModel(IDataAccessService dataAccessService, ModelId modelId) {
168         try {
169             sLogger.d(TAG + ": Start fetch model %s %d", modelId.getKey(), modelId.getTableId());
170             BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1);
171             Bundle params = new Bundle();
172             params.putParcelable(Constants.EXTRA_MODEL_ID, modelId);
173             dataAccessService.onRequest(
174                     Constants.DATA_ACCESS_OP_GET_MODEL,
175                     params,
176                     new IDataAccessServiceCallback.Stub() {
177                         @Override
178                         public void onSuccess(Bundle result) {
179                             if (result != null) {
180                                 asyncResult.add(result);
181                             } else {
182                                 asyncResult.add(Bundle.EMPTY);
183                             }
184                         }
185 
186                         @Override
187                         public void onError(int errorCode) {
188                             asyncResult.add(Bundle.EMPTY);
189                         }
190                     });
191             Bundle result = asyncResult.take();
192             ParcelFileDescriptor modelFd =
193                     result.getParcelable(Constants.EXTRA_RESULT, ParcelFileDescriptor.class);
194             Objects.requireNonNull(modelFd);
195             return modelFd;
196         } catch (InterruptedException | RemoteException e) {
197             sLogger.e(TAG + ": Failed to fetch model from DataAccessService", e);
198             throw new IllegalStateException(e);
199         }
200     }
201 
sendError(@onNull IIsolatedModelServiceCallback callback)202     private void sendError(@NonNull IIsolatedModelServiceCallback callback) {
203         try {
204             callback.onError(Constants.STATUS_INTERNAL_ERROR);
205         } catch (RemoteException e) {
206             sLogger.e(TAG + ": Callback error", e);
207         }
208     }
209 
sendResult( @onNull Bundle result, @NonNull IIsolatedModelServiceCallback callback)210     private void sendResult(
211             @NonNull Bundle result, @NonNull IIsolatedModelServiceCallback callback) {
212         try {
213             callback.onSuccess(result);
214         } catch (RemoteException e) {
215             sLogger.e(e, TAG + ": Callback error");
216         }
217     }
218 
219     @VisibleForTesting
220     static class Injector {
getExecutor()221         ListeningExecutorService getExecutor() {
222             return OnDevicePersonalizationExecutors.getBackgroundExecutor();
223         }
224     }
225 }
226