1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.service.topics.classifier; 18 19 import static java.util.stream.Collectors.toMap; 20 import static java.util.stream.Collectors.toSet; 21 22 import android.annotation.NonNull; 23 import android.os.Build; 24 25 import androidx.annotation.RequiresApi; 26 27 import com.android.adservices.LoggerFactory; 28 import com.android.adservices.data.topics.Topic; 29 import com.android.adservices.service.Flags; 30 import com.android.adservices.service.Flags.ClassifierType; 31 import com.android.adservices.service.FlagsFactory; 32 import com.android.adservices.service.stats.AdServicesLoggerImpl; 33 import com.android.adservices.service.topics.CacheManager; 34 import com.android.internal.annotations.VisibleForTesting; 35 36 import com.google.common.base.Supplier; 37 import com.google.common.base.Suppliers; 38 39 import java.util.List; 40 import java.util.Map; 41 import java.util.Map.Entry; 42 import java.util.Random; 43 import java.util.Set; 44 import java.util.stream.Stream; 45 46 /** 47 * Manager class to control the classifier behaviour between available types of classifier based on 48 * classifier flags. 49 */ 50 @RequiresApi(Build.VERSION_CODES.S) 51 public class ClassifierManager implements Classifier { 52 private static final LoggerFactory.Logger sLogger = LoggerFactory.getTopicsLogger(); 53 private static ClassifierManager sSingleton; 54 55 private Supplier<OnDeviceClassifier> mOnDeviceClassifier; 56 private Supplier<PrecomputedClassifier> mPrecomputedClassifier; 57 58 @VisibleForTesting ClassifierManager( @onNull Supplier<OnDeviceClassifier> onDeviceClassifier, @NonNull Supplier<PrecomputedClassifier> precomputedClassifier)59 ClassifierManager( 60 @NonNull Supplier<OnDeviceClassifier> onDeviceClassifier, 61 @NonNull Supplier<PrecomputedClassifier> precomputedClassifier) { 62 mOnDeviceClassifier = onDeviceClassifier; 63 mPrecomputedClassifier = precomputedClassifier; 64 } 65 66 /** Returns the singleton instance of the {@link ClassifierManager} given a context. */ 67 @NonNull getInstance()68 public static ClassifierManager getInstance() { 69 synchronized (ClassifierManager.class) { 70 if (sSingleton == null) { 71 // Note: we need to have a singleton ModelManager shared by both Classifiers. 72 sSingleton = 73 new ClassifierManager( 74 Suppliers.memoize( 75 () -> 76 new OnDeviceClassifier( 77 new Random(), 78 ModelManager.getInstance(), 79 CacheManager.getInstance(), 80 ClassifierInputManager.getInstance(), 81 AdServicesLoggerImpl.getInstance())), 82 Suppliers.memoize( 83 () -> 84 new PrecomputedClassifier( 85 ModelManager.getInstance(), 86 CacheManager.getInstance(), 87 AdServicesLoggerImpl.getInstance()))); 88 } 89 } 90 return sSingleton; 91 } 92 93 /** 94 * {@inheritDoc} 95 * 96 * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values. 97 */ 98 @Override classify(Set<String> apps)99 public Map<String, List<Topic>> classify(Set<String> apps) { 100 Flags flags = FlagsFactory.getFlags(); 101 102 if (flags.getTopicsOnDeviceClassifierKillSwitch()) { 103 sLogger.v( 104 "On-device classifier disabled via topics on device classifier kill switch - " 105 + "falling back to precomputed classifier"); 106 return mPrecomputedClassifier.get().classify(apps); 107 } 108 109 @ClassifierType int classifierTypeFlag = flags.getClassifierType(); 110 if (classifierTypeFlag == Flags.PRECOMPUTED_CLASSIFIER) { 111 sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_CLASSIFIER"); 112 return mPrecomputedClassifier.get().classify(apps); 113 } else if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) { 114 sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = ON_DEVICE_CLASSIFIER"); 115 return mOnDeviceClassifier.get().classify(apps); 116 } else { 117 sLogger.v( 118 "ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_THEN_ON_DEVICE"); 119 // PRECOMPUTED_THEN_ON_DEVICE 120 // Default if classifierTypeFlag value is not set/invalid. 121 // precomputedClassifications expects non-empty values. 122 Map<String, List<Topic>> precomputedClassifications = 123 mPrecomputedClassifier.get().classify(apps); 124 // Collect package names that do not have any topics in the precomputed list. 125 Set<String> remainingApps = 126 apps.stream() 127 .filter( 128 packageName -> 129 !isValidValue(packageName, precomputedClassifications)) 130 .collect(toSet()); 131 Map<String, List<Topic>> onDeviceClassifications = 132 mOnDeviceClassifier.get().classify(remainingApps); 133 134 // Combine classification values. On device classifications are used for values that 135 // do not have valid precomputed classifications. 136 Map<String, List<Topic>> combinedClassifications = 137 Stream.concat( 138 onDeviceClassifications.entrySet().stream(), 139 precomputedClassifications.entrySet().stream()) 140 .collect( 141 toMap( 142 Entry::getKey, 143 Entry::getValue, 144 ClassifierManager::combineTopics)); 145 return combinedClassifications; 146 } 147 } 148 149 /** 150 * {@inheritDoc} 151 * 152 * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values. 153 */ 154 @Override getTopTopics( Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics)155 public List<Topic> getTopTopics( 156 Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics) { 157 Flags flags = FlagsFactory.getFlags(); 158 159 if (flags.getTopicsOnDeviceClassifierKillSwitch()) { 160 sLogger.v( 161 "On-device classifier disabled via topics on device classifier kill switch - " 162 + "falling back to precomputed classifier"); 163 return mPrecomputedClassifier 164 .get() 165 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 166 } 167 168 @ClassifierType int classifierTypeFlag = flags.getClassifierType(); 169 // getTopTopics has the same implementation. 170 // If the loaded assets are same, the output will be same. 171 if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) { 172 return mOnDeviceClassifier 173 .get() 174 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 175 } else { 176 // Use getTopics from PrecomputedClassifier as default. 177 return mPrecomputedClassifier 178 .get() 179 .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics); 180 } 181 } 182 183 /** 184 * Gets the topics taxonomy based on the classifier type flag values. 185 * 186 * @return The topics taxonomy for the enabled classifier. 187 */ getTopicsTaxonomy()188 public List<Integer> getTopicsTaxonomy() { 189 Flags flags = FlagsFactory.getFlags(); 190 191 if (flags.getTopicsOnDeviceClassifierKillSwitch()) { 192 sLogger.v( 193 "On-device classifier disabled via topics on device classifier kill switch - " 194 + "falling back to precomputed classifier"); 195 return mPrecomputedClassifier.get().getLabels(); 196 } 197 198 @ClassifierType int classifierTypeFlag = flags.getClassifierType(); 199 // getLabels has the same implementation. 200 // If the loaded assets are same, the output will be same. 201 if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) { 202 return mOnDeviceClassifier.get().getLabels(); 203 } else { 204 // Use getLabels from PrecomputedClassifier as default. 205 return mPrecomputedClassifier.get().getLabels(); 206 } 207 } 208 209 // Prefer precomputed values for topics if the list is not empty. combineTopics( List<Topic> onDeviceValue, List<Topic> precomputedValue)210 private static List<Topic> combineTopics( 211 List<Topic> onDeviceValue, List<Topic> precomputedValue) { 212 if (!precomputedValue.isEmpty()) { 213 return precomputedValue; 214 } 215 return onDeviceValue; 216 } 217 218 // Return true if package name has non-empty list of topics in the classifications. isValidValue(String packageName, Map<String, List<Topic>> classifications)219 private boolean isValidValue(String packageName, Map<String, List<Topic>> classifications) { 220 if (classifications.containsKey(packageName) 221 && !classifications.get(packageName).isEmpty()) { 222 return true; 223 } 224 return false; 225 } 226 } 227