1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.service.topics.classifier;
18 
19 import static java.util.stream.Collectors.toMap;
20 import static java.util.stream.Collectors.toSet;
21 
22 import android.annotation.NonNull;
23 import android.os.Build;
24 
25 import androidx.annotation.RequiresApi;
26 
27 import com.android.adservices.LoggerFactory;
28 import com.android.adservices.data.topics.Topic;
29 import com.android.adservices.service.Flags;
30 import com.android.adservices.service.Flags.ClassifierType;
31 import com.android.adservices.service.FlagsFactory;
32 import com.android.adservices.service.stats.AdServicesLoggerImpl;
33 import com.android.adservices.service.topics.CacheManager;
34 import com.android.internal.annotations.VisibleForTesting;
35 
36 import com.google.common.base.Supplier;
37 import com.google.common.base.Suppliers;
38 
39 import java.util.List;
40 import java.util.Map;
41 import java.util.Map.Entry;
42 import java.util.Random;
43 import java.util.Set;
44 import java.util.stream.Stream;
45 
46 /**
47  * Manager class to control the classifier behaviour between available types of classifier based on
48  * classifier flags.
49  */
50 @RequiresApi(Build.VERSION_CODES.S)
51 public class ClassifierManager implements Classifier {
52     private static final LoggerFactory.Logger sLogger = LoggerFactory.getTopicsLogger();
53     private static ClassifierManager sSingleton;
54 
55     private Supplier<OnDeviceClassifier> mOnDeviceClassifier;
56     private Supplier<PrecomputedClassifier> mPrecomputedClassifier;
57 
58     @VisibleForTesting
ClassifierManager( @onNull Supplier<OnDeviceClassifier> onDeviceClassifier, @NonNull Supplier<PrecomputedClassifier> precomputedClassifier)59     ClassifierManager(
60             @NonNull Supplier<OnDeviceClassifier> onDeviceClassifier,
61             @NonNull Supplier<PrecomputedClassifier> precomputedClassifier) {
62         mOnDeviceClassifier = onDeviceClassifier;
63         mPrecomputedClassifier = precomputedClassifier;
64     }
65 
66     /** Returns the singleton instance of the {@link ClassifierManager} given a context. */
67     @NonNull
getInstance()68     public static ClassifierManager getInstance() {
69         synchronized (ClassifierManager.class) {
70             if (sSingleton == null) {
71                 // Note: we need to have a singleton ModelManager shared by both Classifiers.
72                 sSingleton =
73                         new ClassifierManager(
74                                 Suppliers.memoize(
75                                         () ->
76                                                 new OnDeviceClassifier(
77                                                         new Random(),
78                                                         ModelManager.getInstance(),
79                                                         CacheManager.getInstance(),
80                                                         ClassifierInputManager.getInstance(),
81                                                         AdServicesLoggerImpl.getInstance())),
82                                 Suppliers.memoize(
83                                         () ->
84                                                 new PrecomputedClassifier(
85                                                         ModelManager.getInstance(),
86                                                         CacheManager.getInstance(),
87                                                         AdServicesLoggerImpl.getInstance())));
88             }
89         }
90         return sSingleton;
91     }
92 
93     /**
94      * {@inheritDoc}
95      *
96      * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values.
97      */
98     @Override
classify(Set<String> apps)99     public Map<String, List<Topic>> classify(Set<String> apps) {
100         Flags flags = FlagsFactory.getFlags();
101 
102         if (flags.getTopicsOnDeviceClassifierKillSwitch()) {
103             sLogger.v(
104                     "On-device classifier disabled via topics on device classifier kill switch - "
105                             + "falling back to precomputed classifier");
106             return mPrecomputedClassifier.get().classify(apps);
107         }
108 
109         @ClassifierType int classifierTypeFlag = flags.getClassifierType();
110         if (classifierTypeFlag == Flags.PRECOMPUTED_CLASSIFIER) {
111             sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_CLASSIFIER");
112             return mPrecomputedClassifier.get().classify(apps);
113         } else if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) {
114             sLogger.v("ClassifierTypeFlag: " + classifierTypeFlag + " = ON_DEVICE_CLASSIFIER");
115             return mOnDeviceClassifier.get().classify(apps);
116         } else {
117             sLogger.v(
118                     "ClassifierTypeFlag: " + classifierTypeFlag + " = PRECOMPUTED_THEN_ON_DEVICE");
119             // PRECOMPUTED_THEN_ON_DEVICE
120             // Default if classifierTypeFlag value is not set/invalid.
121             // precomputedClassifications expects non-empty values.
122             Map<String, List<Topic>> precomputedClassifications =
123                     mPrecomputedClassifier.get().classify(apps);
124             // Collect package names that do not have any topics in the precomputed list.
125             Set<String> remainingApps =
126                     apps.stream()
127                             .filter(
128                                     packageName ->
129                                             !isValidValue(packageName, precomputedClassifications))
130                             .collect(toSet());
131             Map<String, List<Topic>> onDeviceClassifications =
132                     mOnDeviceClassifier.get().classify(remainingApps);
133 
134             // Combine classification values. On device classifications are used for values that
135             // do not have valid precomputed classifications.
136             Map<String, List<Topic>> combinedClassifications =
137                     Stream.concat(
138                                     onDeviceClassifications.entrySet().stream(),
139                                     precomputedClassifications.entrySet().stream())
140                             .collect(
141                                     toMap(
142                                             Entry::getKey,
143                                             Entry::getValue,
144                                             ClassifierManager::combineTopics));
145             return combinedClassifications;
146         }
147     }
148 
149     /**
150      * {@inheritDoc}
151      *
152      * <p>Invokes a particular {@link Classifier} instance based on the classifier type flag values.
153      */
154     @Override
getTopTopics( Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics)155     public List<Topic> getTopTopics(
156             Map<String, List<Topic>> appTopics, int numberOfTopTopics, int numberOfRandomTopics) {
157         Flags flags = FlagsFactory.getFlags();
158 
159         if (flags.getTopicsOnDeviceClassifierKillSwitch()) {
160             sLogger.v(
161                     "On-device classifier disabled via topics on device classifier kill switch - "
162                             + "falling back to precomputed classifier");
163             return mPrecomputedClassifier
164                     .get()
165                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
166         }
167 
168         @ClassifierType int classifierTypeFlag = flags.getClassifierType();
169         // getTopTopics has the same implementation.
170         // If the loaded assets are same, the output will be same.
171         if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) {
172             return mOnDeviceClassifier
173                     .get()
174                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
175         } else {
176             // Use getTopics from PrecomputedClassifier as default.
177             return mPrecomputedClassifier
178                     .get()
179                     .getTopTopics(appTopics, numberOfTopTopics, numberOfRandomTopics);
180         }
181     }
182 
183     /**
184      * Gets the topics taxonomy based on the classifier type flag values.
185      *
186      * @return The topics taxonomy for the enabled classifier.
187      */
getTopicsTaxonomy()188     public List<Integer> getTopicsTaxonomy() {
189         Flags flags = FlagsFactory.getFlags();
190 
191         if (flags.getTopicsOnDeviceClassifierKillSwitch()) {
192             sLogger.v(
193                     "On-device classifier disabled via topics on device classifier kill switch - "
194                             + "falling back to precomputed classifier");
195             return mPrecomputedClassifier.get().getLabels();
196         }
197 
198         @ClassifierType int classifierTypeFlag = flags.getClassifierType();
199         // getLabels has the same implementation.
200         // If the loaded assets are same, the output will be same.
201         if (classifierTypeFlag == Flags.ON_DEVICE_CLASSIFIER) {
202             return mOnDeviceClassifier.get().getLabels();
203         } else {
204             // Use getLabels from PrecomputedClassifier as default.
205             return mPrecomputedClassifier.get().getLabels();
206         }
207     }
208 
209     // Prefer precomputed values for topics if the list is not empty.
combineTopics( List<Topic> onDeviceValue, List<Topic> precomputedValue)210     private static List<Topic> combineTopics(
211             List<Topic> onDeviceValue, List<Topic> precomputedValue) {
212         if (!precomputedValue.isEmpty()) {
213             return precomputedValue;
214         }
215         return onDeviceValue;
216     }
217 
218     // Return true if package name has non-empty list of topics in the classifications.
isValidValue(String packageName, Map<String, List<Topic>> classifications)219     private boolean isValidValue(String packageName, Map<String, List<Topic>> classifications) {
220         if (classifications.containsKey(packageName)
221                 && !classifications.get(packageName).isEmpty()) {
222             return true;
223         }
224         return false;
225     }
226 }
227