1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
19 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
20 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
21 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
22 import android.annotation.CallbackExecutor;
23 import android.annotation.FlaggedApi;
24 import android.annotation.NonNull;
25 import android.annotation.WorkerThread;
26 import android.os.Bundle;
27 import android.os.OutcomeReceiver;
28 import android.os.RemoteException;
29 
30 import com.android.adservices.ondevicepersonalization.flags.Flags;
31 import com.android.ondevicepersonalization.internal.util.LoggerFactory;
32 
33 import java.util.Objects;
34 import java.util.concurrent.Executor;
35 
36 /**
37  * Handles model inference and only support TFLite model inference now. See {@link
38  * IsolatedService#getModelManager}.
39  */
40 @FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED)
41 public class ModelManager {
42     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
43     private static final String TAG = ModelManager.class.getSimpleName();
44     @NonNull private final IDataAccessService mDataService;
45 
46     @NonNull private final IIsolatedModelService mModelService;
47 
48     /** @hide */
ModelManager( @onNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService)49     public ModelManager(
50             @NonNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService) {
51         mDataService = dataService;
52         mModelService = modelService;
53     }
54 
55     /**
56      * Run a single model inference. Only supports TFLite model inference now.
57      *
58      * @param input contains all the information needed for a run of model inference.
59      * @param executor the {@link Executor} on which to invoke the callback.
60      * @param receiver this returns a {@link InferenceOutput} which contains model inference result
61      *     or {@link Exception} on failure.
62      */
63     @WorkerThread
run( @onNull InferenceInput input, @NonNull @CallbackExecutor Executor executor, @NonNull OutcomeReceiver<InferenceOutput, Exception> receiver)64     public void run(
65             @NonNull InferenceInput input,
66             @NonNull @CallbackExecutor Executor executor,
67             @NonNull OutcomeReceiver<InferenceOutput, Exception> receiver) {
68         final long startTimeMillis = System.currentTimeMillis();
69         Objects.requireNonNull(input);
70         Bundle bundle = new Bundle();
71         bundle.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, mDataService.asBinder());
72         bundle.putParcelable(Constants.EXTRA_INFERENCE_INPUT, new InferenceInputParcel(input));
73         try {
74             mModelService.runInference(
75                     bundle,
76                     new IIsolatedModelServiceCallback.Stub() {
77                         @Override
78                         public void onSuccess(Bundle result) {
79                             executor.execute(
80                                     () -> {
81                                         int responseCode = Constants.STATUS_SUCCESS;
82                                         long endTimeMillis = System.currentTimeMillis();
83                                         try {
84                                             InferenceOutputParcel outputParcel =
85                                                     Objects.requireNonNull(
86                                                             result.getParcelable(
87                                                                     Constants.EXTRA_RESULT,
88                                                                     InferenceOutputParcel.class));
89                                             InferenceOutput output =
90                                                     new InferenceOutput(outputParcel.getData());
91                                             endTimeMillis = System.currentTimeMillis();
92                                             receiver.onResult(output);
93                                         } catch (Exception e) {
94                                             endTimeMillis = System.currentTimeMillis();
95                                             responseCode = Constants.STATUS_INTERNAL_ERROR;
96                                             receiver.onError(e);
97                                         } finally {
98                                             logApiCallStats(
99                                                     Constants.API_NAME_MODEL_MANAGER_RUN,
100                                                     endTimeMillis - startTimeMillis,
101                                                     responseCode);
102                                         }
103                                     });
104                         }
105 
106                         @Override
107                         public void onError(int errorCode) {
108                             executor.execute(
109                                     () -> {
110                                         long endTimeMillis = System.currentTimeMillis();
111                                         receiver.onError(
112                                             new IllegalStateException("Error: " + errorCode));
113                                         logApiCallStats(
114                                                 Constants.API_NAME_MODEL_MANAGER_RUN,
115                                                 endTimeMillis - startTimeMillis,
116                                                 Constants.STATUS_INTERNAL_ERROR);
117                                     });
118                         }
119                     });
120         } catch (RemoteException e) {
121             receiver.onError(new IllegalStateException(e));
122         }
123     }
124 
logApiCallStats(int apiName, long duration, int responseCode)125     private void logApiCallStats(int apiName, long duration, int responseCode) {
126         try {
127             mDataService.logApiCallStats(apiName, duration, responseCode);
128         } catch (Exception e) {
129             sLogger.d(e, TAG + ": failed to log metrics");
130         }
131     }
132 }
133