/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.ondevicepersonalization.services.inference;

import android.adservices.ondevicepersonalization.Constants;
import android.adservices.ondevicepersonalization.InferenceInputParcel;
import android.adservices.ondevicepersonalization.InferenceOutput;
import android.adservices.ondevicepersonalization.InferenceOutputParcel;
import android.adservices.ondevicepersonalization.ModelId;
import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
import android.annotation.NonNull;
import android.os.Bundle;
import android.os.ParcelFileDescriptor;
import android.os.RemoteException;
import android.os.Trace;

import com.android.internal.annotations.VisibleForTesting;
import com.android.ondevicepersonalization.internal.util.LoggerFactory;
import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors;
import com.android.ondevicepersonalization.services.util.IoUtils;

import com.google.common.util.concurrent.ListeningExecutorService;

import org.tensorflow.lite.InterpreterApi;
import org.tensorflow.lite.Tensor;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

/**
 * The implementation of {@link IsolatedModelService}.
 *
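 * <p>The class loads the TFLite model supplied through the {@link IDataAccessService} binder and
 * runs inference against the caller-provided inputs on a background executor. A rough usage
 * sketch of the Bundle-based AIDL contract is shown below; it is illustrative only, and
 * {@code modelService}, {@code inputParcel}, and {@code dataAccessService} are assumed to be
 * supplied by the calling code rather than defined in this file.
 *
 * <pre>{@code
 * Bundle params = new Bundle();
 * params.putParcelable(Constants.EXTRA_INFERENCE_INPUT, inputParcel);
 * params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, dataAccessService.asBinder());
 * modelService.runInference(params, new IIsolatedModelServiceCallback.Stub() {
 *     @Override
 *     public void onSuccess(Bundle result) {
 *         // On success, EXTRA_RESULT carries an InferenceOutputParcel with the output data.
 *         InferenceOutputParcel output =
 *                 result.getParcelable(Constants.EXTRA_RESULT, InferenceOutputParcel.class);
 *     }
 *
 *     @Override
 *     public void onError(int errorCode) {
 *         // errorCode is Constants.STATUS_INTERNAL_ERROR when inference fails.
 *     }
 * });
 * }</pre>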
 */
public class IsolatedModelServiceImpl extends IIsolatedModelService.Stub {
    private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
    private static final String TAG = IsolatedModelServiceImpl.class.getSimpleName();
    @NonNull private final Injector mInjector;

    static {
        System.loadLibrary("fcp_cpp_dep_jni");
    }

    @VisibleForTesting
    public IsolatedModelServiceImpl(@NonNull Injector injector) {
        this.mInjector = injector;
    }

    public IsolatedModelServiceImpl() {
        this(new Injector());
    }

    @Override
    public void runInference(Bundle params, IIsolatedModelServiceCallback callback) {
        InferenceInputParcel inputParcel =
                Objects.requireNonNull(
                        params.getParcelable(
                                Constants.EXTRA_INFERENCE_INPUT, InferenceInputParcel.class));
        IDataAccessService binder =
                IDataAccessService.Stub.asInterface(
                        Objects.requireNonNull(
                                params.getBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER)));
        InferenceOutputParcel outputParcel =
                Objects.requireNonNull(inputParcel.getExpectedOutputStructure());
        mInjector
                .getExecutor()
                .execute(() -> runTfliteInterpreter(inputParcel, outputParcel, binder, callback));
    }

    private void runTfliteInterpreter(
            InferenceInputParcel inputParcel,
            InferenceOutputParcel outputParcel,
            IDataAccessService binder,
            IIsolatedModelServiceCallback callback) {
        try {
            Trace.beginSection("IsolatedModelService#RunInference");
            // Deserialize the caller-provided input tensors.
            Object[] inputs = convertToObjArray(inputParcel.getInputData().getList());
            if (inputs.length == 0) {
                sendError(callback);
                return;
            }

            // Fetch the model file from the data access service and map it into memory.
            ModelId modelId = inputParcel.getModelId();
            ParcelFileDescriptor modelFd = fetchModel(binder, modelId);
            ByteBuffer byteBuffer = IoUtils.getByteBufferFromFd(modelFd);
            if (byteBuffer == null) {
                closeFd(modelFd);
                sendError(callback);
                return;
            }
            InterpreterApi interpreter =
                    InterpreterApi.create(
                            byteBuffer,
                            new InterpreterApi.Options()
                                    .setNumThreads(inputParcel.getCpuNumThread()));
            Map<Integer, Object> outputs = outputParcel.getData();
            if (outputs.isEmpty()) {
                closeFd(modelFd);
                sendError(callback);
                return;
            }

            // TODO(b/323469981): handle batch size better. Currently TFLite throws an error if
            // batchSize doesn't match the input data size.
            int batchSize = inputParcel.getBatchSize();
            for (int i = 0; i < interpreter.getInputTensorCount(); i++) {
                Tensor tensor = interpreter.getInputTensor(i);
                int[] shape = tensor.shape();
                shape[0] = batchSize;
                interpreter.resizeInput(i, shape);
            }
            interpreter.runForMultipleInputsOutputs(inputs, outputs);
            interpreter.close();

            closeFd(modelFd);
            Bundle bundle = new Bundle();
            InferenceOutput result = new InferenceOutput.Builder().setDataOutputs(outputs).build();
            bundle.putParcelable(Constants.EXTRA_RESULT, new InferenceOutputParcel(result));
            sendResult(bundle, callback);
        } catch (Exception e) {
            // Catch all exceptions, including TFLite errors.
            sLogger.e(e, TAG + ": Failed to run inference job.");
            sendError(callback);
        } finally {
            // Always balance the trace section opened above, even when inference fails.
            Trace.endSection();
        }
    }

    private Object[] convertToObjArray(List<byte[]> input) {
        // Each entry is a Java-serialized object; deserialize it into the form expected by the
        // TFLite interpreter.
        Object[] output = new Object[input.size()];
        for (int i = 0; i < input.size(); i++) {
            ByteArrayInputStream bais = new ByteArrayInputStream(input.get(i));
            try {
                ObjectInputStream ois = new ObjectInputStream(bais);
                output[i] = ois.readObject();
            } catch (IOException | ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
        return output;
    }

    private void closeFd(ParcelFileDescriptor fd) {
        try {
            fd.close();
        } catch (IOException e) {
            sLogger.e(e, TAG + ": Failed to close model file descriptor");
        }
    }

    private ParcelFileDescriptor fetchModel(IDataAccessService dataAccessService, ModelId modelId) {
        try {
            sLogger.d(TAG + ": Start fetch model %s %d", modelId.getKey(), modelId.getTableId());
            // Block on the data access callback; the queue holds the single expected response.
            BlockingQueue<Bundle> asyncResult = new ArrayBlockingQueue<>(1);
            Bundle params = new Bundle();
            params.putParcelable(Constants.EXTRA_MODEL_ID, modelId);
            dataAccessService.onRequest(
                    Constants.DATA_ACCESS_OP_GET_MODEL,
                    params,
                    new IDataAccessServiceCallback.Stub() {
                        @Override
                        public void onSuccess(Bundle result) {
                            if (result != null) {
                                asyncResult.add(result);
                            } else {
                                asyncResult.add(Bundle.EMPTY);
                            }
                        }

                        @Override
                        public void onError(int errorCode) {
                            asyncResult.add(Bundle.EMPTY);
                        }
                    });
            Bundle result = asyncResult.take();
            ParcelFileDescriptor modelFd =
                    result.getParcelable(Constants.EXTRA_RESULT, ParcelFileDescriptor.class);
            Objects.requireNonNull(modelFd);
            return modelFd;
        } catch (InterruptedException | RemoteException e) {
            sLogger.e(e, TAG + ": Failed to fetch model from DataAccessService");
            throw new IllegalStateException(e);
        }
    }

    private void sendError(@NonNull IIsolatedModelServiceCallback callback) {
        try {
            callback.onError(Constants.STATUS_INTERNAL_ERROR);
        } catch (RemoteException e) {
            sLogger.e(e, TAG + ": Callback error");
        }
    }

    private void sendResult(
            @NonNull Bundle result, @NonNull IIsolatedModelServiceCallback callback) {
        try {
            callback.onSuccess(result);
        } catch (RemoteException e) {
            sLogger.e(e, TAG + ": Callback error");
        }
    }

    @VisibleForTesting
    static class Injector {
        ListeningExecutorService getExecutor() {
            return OnDevicePersonalizationExecutors.getBackgroundExecutor();
        }
    }
}