/* * Copyright (C) 2024 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package android.adservices.ondevicepersonalization; import android.annotation.FlaggedApi; import android.annotation.IntDef; import android.annotation.IntRange; import android.annotation.NonNull; import android.annotation.SuppressLint; import com.android.adservices.ondevicepersonalization.flags.Flags; import com.android.ondevicepersonalization.internal.util.AnnotationValidations; import com.android.ondevicepersonalization.internal.util.DataClass; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; /** * Contains all the information needed for a run of model inference. The input of {@link * ModelManager#run}. */ @FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED) @DataClass(genBuilder = true, genEqualsHashCode = true) public final class InferenceInput { /** The configuration that controls runtime interpreter behavior. */ @NonNull private Params mParams; /** * An array of input data. The inputs should be in the same order as inputs of the model. * *
For example, if a model takes multiple inputs: * *
{@code * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. * Object[] inputData = {input0, input1, ...}; * }* * For TFLite, this field is mapped to inputs of runForMultipleInputsOutputs: * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 */ @NonNull private Object[] mInputData; /** * The number of input examples. Adopter can set this field to run batching inference. The batch * size is 1 by default. The batch size should match the input data size. */ private int mBatchSize = 1; /** * The empty InferenceOutput representing the expected output structure. For TFLite, the * inference code will verify whether this expected output structure matches model output * signature. * *
If a model produce string tensors: * *
{@code * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. * HashMap*/ @NonNull private InferenceOutput mExpectedOutputStructure; @DataClass(genBuilder = true, genHiddenConstructor = true, genEqualsHashCode = true) public static class Params { /** * A {@link KeyValueStore} where pre-trained model is stored. Only supports TFLite model * now. */ @NonNull private KeyValueStore mKeyValueStore; /** * The key of the table where the corresponding value stores a pre-trained model. Only * supports TFLite model now. */ @NonNull private String mModelKey; /** The model inference will run on CPU. */ public static final int DELEGATE_CPU = 1; /** * The delegate to run model inference. * * @hide */ @IntDef( prefix = "DELEGATE_", value = {DELEGATE_CPU}) @Retention(RetentionPolicy.SOURCE) public @interface Delegate {} /** * The delegate to run model inference. If not set, the default value is {@link * #DELEGATE_CPU}. */ private @Delegate int mDelegateType = DELEGATE_CPU; /** The model is a tensorflow lite model. */ public static final int MODEL_TYPE_TENSORFLOW_LITE = 1; /** * The type of the model. * * @hide */ @IntDef( prefix = "MODEL_TYPE", value = {MODEL_TYPE_TENSORFLOW_LITE}) @Retention(RetentionPolicy.SOURCE) public @interface ModelType {} /** * The type of the pre-trained model. If not set, the default value is {@link * #MODEL_TYPE_TENSORFLOW_LITE} . Only supports {@link #MODEL_TYPE_TENSORFLOW_LITE} for now. */ private @ModelType int mModelType = MODEL_TYPE_TENSORFLOW_LITE; /** * The number of threads used for intraop parallelism on CPU, must be positive number. * Adopters can set this field based on model architecture. The actual thread number depends * on system resources and other constraints. */ private @IntRange(from = 1) int mRecommendedNumThreads = 1; // Code below generated by codegen v1.0.23. // // DO NOT MODIFY! // CHECKSTYLE:OFF Generated code // // To regenerate run: // $ codegen // $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/InferenceInput.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control // @formatter:off /** * Creates a new Params. * * @param keyValueStore A {@link KeyValueStore} where pre-trained model is stored. Only * supports TFLite model now. * @param modelKey The key of the table where the corresponding value stores a pre-trained * model. Only supports TFLite model now. * @param delegateType The delegate to run model inference. If not set, the default value is * {@link #DELEGATE_CPU}. * @param modelType The type of the pre-trained model. If not set, the default value is * {@link #MODEL_TYPE_TENSORFLOW_LITE} . Only supports {@link * #MODEL_TYPE_TENSORFLOW_LITE} for now. * @param recommendedNumThreads The number of threads used for intraop parallelism on CPU, * must be positive number. Adopters can set this field based on model architecture. The * actual thread number depends on system resources and other constraints. * @hide */ @DataClass.Generated.Member public Params( @NonNull KeyValueStore keyValueStore, @NonNull String modelKey, @Delegate int delegateType, @ModelType int modelType, @IntRange(from = 1) int recommendedNumThreads) { this.mKeyValueStore = keyValueStore; AnnotationValidations.validate(NonNull.class, null, mKeyValueStore); this.mModelKey = modelKey; AnnotationValidations.validate(NonNull.class, null, mModelKey); this.mDelegateType = delegateType; AnnotationValidations.validate(Delegate.class, null, mDelegateType); this.mModelType = modelType; AnnotationValidations.validate(ModelType.class, null, mModelType); this.mRecommendedNumThreads = recommendedNumThreads; AnnotationValidations.validate(IntRange.class, null, mRecommendedNumThreads, "from", 1); // onConstructed(); // You can define this method to get a callback } /** * A {@link KeyValueStore} where pre-trained model is stored. Only supports TFLite model * now. */ @DataClass.Generated.Member public @NonNull KeyValueStore getKeyValueStore() { return mKeyValueStore; } /** * The key of the table where the corresponding value stores a pre-trained model. Only * supports TFLite model now. */ @DataClass.Generated.Member public @NonNull String getModelKey() { return mModelKey; } /** * The delegate to run model inference. If not set, the default value is {@link * #DELEGATE_CPU}. */ @DataClass.Generated.Member public @Delegate int getDelegateType() { return mDelegateType; } /** * The type of the pre-trained model. If not set, the default value is {@link * #MODEL_TYPE_TENSORFLOW_LITE} . Only supports {@link #MODEL_TYPE_TENSORFLOW_LITE} for now. */ @DataClass.Generated.Member public @ModelType int getModelType() { return mModelType; } /** * The number of threads used for intraop parallelism on CPU, must be positive number. * Adopters can set this field based on model architecture. The actual thread number depends * on system resources and other constraints. */ @DataClass.Generated.Member public @IntRange(from = 1) int getRecommendedNumThreads() { return mRecommendedNumThreads; } @Override @DataClass.Generated.Member public boolean equals(@android.annotation.Nullable Object o) { // You can override field equality logic by defining either of the methods like: // boolean fieldNameEquals(Params other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") Params that = (Params) o; //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mKeyValueStore, that.mKeyValueStore) && java.util.Objects.equals(mModelKey, that.mModelKey) && mDelegateType == that.mDelegateType && mModelType == that.mModelType && mRecommendedNumThreads == that.mRecommendedNumThreads; } @Override @DataClass.Generated.Member public int hashCode() { // You can override field hashCode logic by defining methods like: // int fieldNameHashCode() { ... } int _hash = 1; _hash = 31 * _hash + java.util.Objects.hashCode(mKeyValueStore); _hash = 31 * _hash + java.util.Objects.hashCode(mModelKey); _hash = 31 * _hash + mDelegateType; _hash = 31 * _hash + mModelType; _hash = 31 * _hash + mRecommendedNumThreads; return _hash; } /** A builder for {@link Params} */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { private @NonNull KeyValueStore mKeyValueStore; private @NonNull String mModelKey; private @Delegate int mDelegateType; private @ModelType int mModelType; private @IntRange(from = 1) int mRecommendedNumThreads; private long mBuilderFieldsSet = 0L; /** * Creates a new Builder. * * @param keyValueStore A {@link KeyValueStore} where pre-trained model is stored. Only * supports TFLite model now. * @param modelKey The key of the table where the corresponding value stores a * pre-trained model. Only supports TFLite model now. */ public Builder(@NonNull KeyValueStore keyValueStore, @NonNull String modelKey) { mKeyValueStore = keyValueStore; AnnotationValidations.validate(NonNull.class, null, mKeyValueStore); mModelKey = modelKey; AnnotationValidations.validate(NonNull.class, null, mModelKey); } /** * A {@link KeyValueStore} where pre-trained model is stored. Only supports TFLite model * now. */ @DataClass.Generated.Member public @NonNull Builder setKeyValueStore(@NonNull KeyValueStore value) { mBuilderFieldsSet |= 0x1; mKeyValueStore = value; return this; } /** * The key of the table where the corresponding value stores a pre-trained model. Only * supports TFLite model now. */ @DataClass.Generated.Member public @NonNull Builder setModelKey(@NonNull String value) { mBuilderFieldsSet |= 0x2; mModelKey = value; return this; } /** * The delegate to run model inference. If not set, the default value is {@link * #DELEGATE_CPU}. */ @DataClass.Generated.Member public @NonNull Builder setDelegateType(@Delegate int value) { mBuilderFieldsSet |= 0x4; mDelegateType = value; return this; } /** * The type of the pre-trained model. If not set, the default value is {@link * #MODEL_TYPE_TENSORFLOW_LITE} . Only supports {@link #MODEL_TYPE_TENSORFLOW_LITE} for * now. */ @DataClass.Generated.Member public @NonNull Builder setModelType(@ModelType int value) { mBuilderFieldsSet |= 0x8; mModelType = value; return this; } /** * The number of threads used for intraop parallelism on CPU, must be positive number. * Adopters can set this field based on model architecture. The actual thread number * depends on system resources and other constraints. */ @DataClass.Generated.Member public @NonNull Builder setRecommendedNumThreads(@IntRange(from = 1) int value) { mBuilderFieldsSet |= 0x10; mRecommendedNumThreads = value; return this; } /** Builds the instance. */ public @NonNull Params build() { mBuilderFieldsSet |= 0x20; // Mark builder used if ((mBuilderFieldsSet & 0x4) == 0) { mDelegateType = DELEGATE_CPU; } if ((mBuilderFieldsSet & 0x8) == 0) { mModelType = MODEL_TYPE_TENSORFLOW_LITE; } if ((mBuilderFieldsSet & 0x10) == 0) { mRecommendedNumThreads = 1; } Params o = new Params( mKeyValueStore, mModelKey, mDelegateType, mModelType, mRecommendedNumThreads); return o; } } @DataClass.Generated( time = 1709250081597L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/InferenceInput.java", inputSignatures = "private @android.annotation.NonNull android.adservices.ondevicepersonalization.KeyValueStore mKeyValueStore\nprivate @android.annotation.NonNull java.lang.String mModelKey\npublic static final int DELEGATE_CPU\nprivate @android.adservices.ondevicepersonalization.Params.Delegate int mDelegateType\npublic static final int MODEL_TYPE_TENSORFLOW_LITE\nprivate @android.adservices.ondevicepersonalization.Params.ModelType int mModelType\nprivate @android.annotation.IntRange int mRecommendedNumThreads\nclass Params extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genHiddenConstructor=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} // @formatter:on // End of generated code } // Code below generated by codegen v1.0.23. // // DO NOT MODIFY! // CHECKSTYLE:OFF Generated code // // To regenerate run: // $ codegen // $ANDROID_BUILD_TOP/packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/InferenceInput.java // // To exclude the generated code from IntelliJ auto-formatting enable (one-time): // Settings > Editor > Code Style > Formatter Control // @formatter:off @DataClass.Generated.Member /* package-private */ InferenceInput( @NonNull Params params, @NonNull Object[] inputData, int batchSize, @NonNull InferenceOutput expectedOutputStructure) { this.mParams = params; AnnotationValidations.validate(NonNull.class, null, mParams); this.mInputData = inputData; AnnotationValidations.validate(NonNull.class, null, mInputData); this.mBatchSize = batchSize; this.mExpectedOutputStructure = expectedOutputStructure; AnnotationValidations.validate(NonNull.class, null, mExpectedOutputStructure); // onConstructed(); // You can define this method to get a callback } /** The configuration that controls runtime interpreter behavior. */ @DataClass.Generated.Member public @NonNull Params getParams() { return mParams; } /** * An array of input data. The inputs should be in the same order as inputs of the model. * *outputs = new HashMap<>(); * outputs.put(0, output); * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); * }
For example, if a model takes multiple inputs: * *
{@code * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. * Object[] inputData = {input0, input1, ...}; * }* * For TFLite, this field is mapped to inputs of runForMultipleInputsOutputs: * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 */ @SuppressLint("ArrayReturn") @DataClass.Generated.Member public @NonNull Object[] getInputData() { return mInputData; } /** * The number of input examples. Adopter can set this field to run batching inference. The batch * size is 1 by default. The batch size should match the input data size. */ @DataClass.Generated.Member public int getBatchSize() { return mBatchSize; } /** * The empty InferenceOutput representing the expected output structure. For TFLite, the * inference code will verify whether this expected output structure matches model output * signature. * *
If a model produce string tensors: * *
{@code * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. * HashMap*/ @DataClass.Generated.Member public @NonNull InferenceOutput getExpectedOutputStructure() { return mExpectedOutputStructure; } @Override @DataClass.Generated.Member public boolean equals(@android.annotation.Nullable Object o) { // You can override field equality logic by defining either of the methods like: // boolean fieldNameEquals(InferenceInput other) { ... } // boolean fieldNameEquals(FieldType otherValue) { ... } if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; @SuppressWarnings("unchecked") InferenceInput that = (InferenceInput) o; //noinspection PointlessBooleanExpression return true && java.util.Objects.equals(mParams, that.mParams) && java.util.Arrays.equals(mInputData, that.mInputData) && mBatchSize == that.mBatchSize && java.util.Objects.equals( mExpectedOutputStructure, that.mExpectedOutputStructure); } @Override @DataClass.Generated.Member public int hashCode() { // You can override field hashCode logic by defining methods like: // int fieldNameHashCode() { ... } int _hash = 1; _hash = 31 * _hash + java.util.Objects.hashCode(mParams); _hash = 31 * _hash + java.util.Arrays.hashCode(mInputData); _hash = 31 * _hash + mBatchSize; _hash = 31 * _hash + java.util.Objects.hashCode(mExpectedOutputStructure); return _hash; } /** A builder for {@link InferenceInput} */ @SuppressWarnings("WeakerAccess") @DataClass.Generated.Member public static final class Builder { private @NonNull Params mParams; private @NonNull Object[] mInputData; private int mBatchSize; private @NonNull InferenceOutput mExpectedOutputStructure; private long mBuilderFieldsSet = 0L; /** * Creates a new Builder. * * @param params The configuration that controls runtime interpreter behavior. * @param inputData An array of input data. The inputs should be in the same order as inputs * of the model. *outputs = new HashMap<>(); * outputs.put(0, output); * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); * }
For example, if a model takes multiple inputs: *
{@code * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. * Object[] inputData = {input0, input1, ...}; * * }* For TFLite, this field is mapped to inputs of runForMultipleInputsOutputs: * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 * @param expectedOutputStructure The empty InferenceOutput representing the expected output * structure. For TFLite, the inference code will verify whether this expected output * structure matches model output signature. *
If a model produce string tensors: *
{@code * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. * HashMap*/ public Builder( @NonNull Params params, @SuppressLint("ArrayReturn") @NonNull Object[] inputData, @NonNull InferenceOutput expectedOutputStructure) { mParams = params; AnnotationValidations.validate(NonNull.class, null, mParams); mInputData = inputData; AnnotationValidations.validate(NonNull.class, null, mInputData); mExpectedOutputStructure = expectedOutputStructure; AnnotationValidations.validate(NonNull.class, null, mExpectedOutputStructure); } /** The configuration that controls runtime interpreter behavior. */ @DataClass.Generated.Member public @NonNull Builder setParams(@NonNull Params value) { mBuilderFieldsSet |= 0x1; mParams = value; return this; } /** * An array of input data. The inputs should be in the same order as inputs of the model. * *outputs = new HashMap<>(); * outputs.put(0, output); * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); * * }
For example, if a model takes multiple inputs: * *
{@code * String[] input0 = {"foo", "bar"}; // string tensor shape is [2]. * int[] input1 = new int[]{3, 2, 1}; // int tensor shape is [3]. * Object[] inputData = {input0, input1, ...}; * }* * For TFLite, this field is mapped to inputs of runForMultipleInputsOutputs: * https://www.tensorflow.org/lite/api_docs/java/org/tensorflow/lite/InterpreterApi#parameters_9 */ @DataClass.Generated.Member public @NonNull Builder setInputData(@NonNull Object... value) { mBuilderFieldsSet |= 0x2; mInputData = value; return this; } /** * The number of input examples. Adopter can set this field to run batching inference. The * batch size is 1 by default. The batch size should match the input data size. */ @DataClass.Generated.Member public @NonNull Builder setBatchSize(int value) { mBuilderFieldsSet |= 0x4; mBatchSize = value; return this; } /** * The empty InferenceOutput representing the expected output structure. For TFLite, the * inference code will verify whether this expected output structure matches model output * signature. * *
If a model produce string tensors: * *
{@code * String[] output = new String[3][2]; // Output tensor shape is [3, 2]. * HashMap*/ @DataClass.Generated.Member public @NonNull Builder setExpectedOutputStructure(@NonNull InferenceOutput value) { mBuilderFieldsSet |= 0x8; mExpectedOutputStructure = value; return this; } /** Builds the instance. */ public @NonNull InferenceInput build() { mBuilderFieldsSet |= 0x10; // Mark builder used if ((mBuilderFieldsSet & 0x4) == 0) { mBatchSize = 1; } InferenceInput o = new InferenceInput(mParams, mInputData, mBatchSize, mExpectedOutputStructure); return o; } } @DataClass.Generated( time = 1709250081618L, codegenVersion = "1.0.23", sourceFile = "packages/modules/OnDevicePersonalization/framework/java/android/adservices/ondevicepersonalization/InferenceInput.java", inputSignatures = "private @android.annotation.NonNull android.adservices.ondevicepersonalization.Params mParams\nprivate @android.annotation.NonNull java.lang.Object[] mInputData\nprivate int mBatchSize\nprivate @android.annotation.NonNull android.adservices.ondevicepersonalization.InferenceOutput mExpectedOutputStructure\nclass InferenceInput extends java.lang.Object implements []\n@com.android.ondevicepersonalization.internal.util.DataClass(genBuilder=true, genEqualsHashCode=true)") @Deprecated private void __metadata() {} // @formatter:on // End of generated code }outputs = new HashMap<>(); * outputs.put(0, output); * expectedOutputStructure = new InferenceOutput.Builder().setDataOutputs(outputs).build(); * }