/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.adservices.ondevicepersonalization;

import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
import android.annotation.CallbackExecutor;
import android.annotation.FlaggedApi;
import android.annotation.NonNull;
import android.annotation.WorkerThread;
import android.os.Bundle;
import android.os.OutcomeReceiver;
import android.os.RemoteException;

import com.android.adservices.ondevicepersonalization.flags.Flags;
import com.android.ondevicepersonalization.internal.util.LoggerFactory;

import java.util.Objects;
import java.util.concurrent.Executor;

/**
 * Handles model inference. Only TFLite model inference is currently supported. See {@link
 * IsolatedService#getModelManager}.
 */
@FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED)
public class ModelManager {
    private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
    private static final String TAG = ModelManager.class.getSimpleName();
    @NonNull private final IDataAccessService mDataService;

    @NonNull private final IIsolatedModelService mModelService;

    /** @hide */
    public ModelManager(
            @NonNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService) {
        mDataService = dataService;
        mModelService = modelService;
    }

    /**
     * Runs a single model inference. Only TFLite model inference is currently supported.
     *
     * <p>The request is forwarded to the model service together with a binder to the data access
     * service. Exactly one of {@code receiver.onResult} or {@code receiver.onError} is invoked per
     * call, always on {@code executor}, and the elapsed time and outcome are reported through
     * the data access service.
     *
     * @param input contains all the information needed for a run of model inference.
     * @param executor the {@link Executor} on which to invoke the callback.
     * @param receiver receives an {@link InferenceOutput} containing the model inference result,
     *     or an {@link Exception} on failure.
     * @throws NullPointerException if {@code input}, {@code executor} or {@code receiver} is
     *     null.
     */
    @WorkerThread
    public void run(
            @NonNull InferenceInput input,
            @NonNull @CallbackExecutor Executor executor,
            @NonNull OutcomeReceiver<InferenceOutput, Exception> receiver) {
        final long startTimeMillis = System.currentTimeMillis();
        Objects.requireNonNull(input);
        // Fail fast on the caller's thread instead of crashing later inside a binder callback.
        Objects.requireNonNull(executor);
        Objects.requireNonNull(receiver);
        Bundle bundle = new Bundle();
        bundle.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, mDataService.asBinder());
        bundle.putParcelable(Constants.EXTRA_INFERENCE_INPUT, new InferenceInputParcel(input));
        try {
            mModelService.runInference(
                    bundle,
                    new IIsolatedModelServiceCallback.Stub() {
                        @Override
                        public void onSuccess(Bundle result) {
                            executor.execute(
                                    () -> deliverResult(result, receiver, startTimeMillis));
                        }

                        @Override
                        public void onError(int errorCode) {
                            executor.execute(
                                    () ->
                                            deliverFailure(
                                                    new IllegalStateException(
                                                            "Error: " + errorCode),
                                                    receiver,
                                                    startTimeMillis));
                        }
                    });
        } catch (RemoteException e) {
            // Report through the executor as well, so the receiver is never invoked on the
            // caller's thread and this failure is logged like any other.
            executor.execute(
                    () ->
                            deliverFailure(
                                    new IllegalStateException(e), receiver, startTimeMillis));
        }
    }

    /**
     * Unpacks the inference result bundle and hands the outcome to {@code receiver}.
     *
     * <p>Unpacking happens before the receiver is called, so an exception thrown by {@code
     * receiver.onResult} is not mistaken for an unpacking failure and does not trigger a second
     * callback.
     */
    private void deliverResult(
            Bundle result,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        InferenceOutput output = null;
        Exception failure = null;
        try {
            InferenceOutputParcel outputParcel =
                    Objects.requireNonNull(
                            result.getParcelable(
                                    Constants.EXTRA_RESULT, InferenceOutputParcel.class));
            output = new InferenceOutput(outputParcel.getData());
        } catch (Exception e) {
            // Includes a null bundle or a missing result parcel; reported via onError below.
            failure = e;
        }
        if (failure != null) {
            deliverFailure(failure, receiver, startTimeMillis);
            return;
        }
        // Take the timestamp before the callback so receiver time is not counted as latency.
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onResult(output);
        } finally {
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_SUCCESS);
        }
    }

    /** Reports {@code failure} to {@code receiver} and logs the call as an internal error. */
    private void deliverFailure(
            Exception failure,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onError(failure);
        } finally {
            // Log even when the receiver itself throws.
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_INTERNAL_ERROR);
        }
    }

    /** Best-effort metrics reporting; any failure to log is swallowed and only debug-logged. */
    private void logApiCallStats(int apiName, long duration, int responseCode) {
        try {
            mDataService.logApiCallStats(apiName, duration, responseCode);
        } catch (Exception e) {
            sLogger.d(e, TAG + ": failed to log metrics");
        }
    }
}