1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 
18 package com.android.intentresolver.model;
19 
20 import android.app.usage.UsageStats;
21 import android.content.ComponentName;
22 import android.content.Context;
23 import android.content.Intent;
24 import android.content.ServiceConnection;
25 import android.content.pm.ActivityInfo;
26 import android.content.pm.ApplicationInfo;
27 import android.content.pm.PackageManager;
28 import android.content.pm.PackageManager.NameNotFoundException;
29 import android.content.pm.ResolveInfo;
30 import android.metrics.LogMaker;
31 import android.os.Handler;
32 import android.os.IBinder;
33 import android.os.Message;
34 import android.os.RemoteException;
35 import android.os.UserHandle;
36 import android.service.resolver.IResolverRankerResult;
37 import android.service.resolver.IResolverRankerService;
38 import android.service.resolver.ResolverRankerService;
39 import android.service.resolver.ResolverTarget;
40 import android.util.Log;
41 
42 import androidx.annotation.Nullable;
43 
44 import com.android.intentresolver.ResolvedComponentInfo;
45 import com.android.intentresolver.chooser.TargetInfo;
46 import com.android.intentresolver.logging.EventLog;
47 import com.android.internal.logging.MetricsLogger;
48 import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
49 
50 import com.google.android.collect.Lists;
51 
52 import java.lang.ref.WeakReference;
53 import java.text.Collator;
54 import java.util.ArrayList;
55 import java.util.Comparator;
56 import java.util.HashMap;
57 import java.util.LinkedHashMap;
58 import java.util.List;
59 import java.util.Map;
60 import java.util.concurrent.CountDownLatch;
61 import java.util.concurrent.TimeUnit;
62 
63 /**
64  * Ranks and compares packages based on usage stats and uses the {@link ResolverRankerService}.
65  */
66 public class ResolverRankerServiceResolverComparator extends AbstractResolverComparator {
67     private static final String TAG = "RRSResolverComparator";
68 
69     private static final boolean DEBUG = false;
70 
71     // One week
72     private static final long USAGE_STATS_PERIOD = 1000 * 60 * 60 * 24 * 7;
73 
74     private static final long RECENCY_TIME_PERIOD = 1000 * 60 * 60 * 12;
75 
76     private static final float RECENCY_MULTIPLIER = 2.f;
77 
78     // timeout for establishing connections with a ResolverRankerService.
79     private static final int CONNECTION_COST_TIMEOUT_MILLIS = 200;
80 
81     private final Collator mCollator;
82     private final Map<UserHandle, Map<String, UsageStats>> mStatsPerUser;
83     private final long mCurrentTime;
84     private final long mSinceTime;
85     private final Map<UserHandle, Map<ComponentName, ResolverTarget>> mTargetsDictPerUser;
86     private final String mReferrerPackage;
87     private final Object mLock = new Object();
88     private ArrayList<ResolverTarget> mTargets;
89     private String mAction;
90     private ComponentName mResolvedRankerName;
91     private ComponentName mRankerServiceName;
92     private IResolverRankerService mRanker;
93     private ResolverRankerServiceConnection mConnection;
94     private Context mContext;
95     private CountDownLatch mConnectSignal;
96     private ResolverRankerServiceComparatorModel mComparatorModel;
97 
98     /**
99      * Constructor to initialize the comparator.
100      * @param launchedFromContext the activity calling this comparator
101      * @param intent original intent
102      * @param targetUserSpace the userSpace(s) used by the comparator for fetching activity stats
103      *                        and recording activity selection. The latter could be different from
104      *                        the userSpace provided by context.
105      */
ResolverRankerServiceResolverComparator(Context launchedFromContext, Intent intent, String referrerPackage, Runnable afterCompute, EventLog eventLog, UserHandle targetUserSpace, ComponentName promoteToFirst)106     public ResolverRankerServiceResolverComparator(Context launchedFromContext, Intent intent,
107                                                    String referrerPackage, Runnable afterCompute,
108                                                    EventLog eventLog, UserHandle targetUserSpace,
109                                                    ComponentName promoteToFirst) {
110         this(launchedFromContext, intent, referrerPackage, afterCompute, eventLog,
111                 Lists.newArrayList(targetUserSpace), promoteToFirst);
112     }
113 
114     /**
115      * Constructor to initialize the comparator.
116      * @param launchedFromContext the activity calling this comparator
117      * @param intent original intent
118      * @param targetUserSpaceList the userSpace(s) used by the comparator for fetching activity
119      *                            stats and recording activity selection. The latter could be
120      *                            different from the userSpace provided by context.
121      */
ResolverRankerServiceResolverComparator(Context launchedFromContext, Intent intent, String referrerPackage, Runnable afterCompute, EventLog eventLog, List<UserHandle> targetUserSpaceList, @Nullable ComponentName promoteToFirst)122     public ResolverRankerServiceResolverComparator(Context launchedFromContext, Intent intent,
123             String referrerPackage, Runnable afterCompute, EventLog eventLog,
124             List<UserHandle> targetUserSpaceList, @Nullable ComponentName promoteToFirst) {
125         super(launchedFromContext, intent, targetUserSpaceList, promoteToFirst);
126         mCollator = Collator.getInstance(
127                 launchedFromContext.getResources().getConfiguration().locale);
128         mReferrerPackage = referrerPackage;
129         mContext = launchedFromContext;
130 
131         mCurrentTime = System.currentTimeMillis();
132         mSinceTime = mCurrentTime - USAGE_STATS_PERIOD;
133         mStatsPerUser = new HashMap<>();
134         mTargetsDictPerUser = new HashMap<>();
135         for (UserHandle user : targetUserSpaceList) {
136             mStatsPerUser.put(
137                     user,
138                     mUsmMap.get(user).queryAndAggregateUsageStats(mSinceTime, mCurrentTime));
139             mTargetsDictPerUser.put(user, new LinkedHashMap<>());
140         }
141         mAction = intent.getAction();
142         mRankerServiceName = new ComponentName(mContext, this.getClass());
143         setCallBack(afterCompute);
144         setEventLog(eventLog);
145 
146         mComparatorModel = buildUpdatedModel();
147     }
148 
149     @Override
handleResultMessage(Message msg)150     public void handleResultMessage(Message msg) {
151         if (msg.what != RANKER_SERVICE_RESULT) {
152             return;
153         }
154         if (msg.obj == null) {
155             Log.e(TAG, "Receiving null prediction results.");
156             return;
157         }
158         final List<ResolverTarget> receivedTargets = (List<ResolverTarget>) msg.obj;
159         if (receivedTargets != null && mTargets != null
160                     && receivedTargets.size() == mTargets.size()) {
161             final int size = mTargets.size();
162             boolean isUpdated = false;
163             for (int i = 0; i < size; ++i) {
164                 final float predictedProb =
165                         receivedTargets.get(i).getSelectProbability();
166                 if (predictedProb != mTargets.get(i).getSelectProbability()) {
167                     mTargets.get(i).setSelectProbability(predictedProb);
168                     isUpdated = true;
169                 }
170             }
171             if (isUpdated) {
172                 mRankerServiceName = mResolvedRankerName;
173                 mComparatorModel = buildUpdatedModel();
174             }
175         } else {
176             Log.e(TAG, "Sizes of sent and received ResolverTargets diff.");
177         }
178     }
179 
180     // compute features for each target according to usage stats of targets.
181     @Override
doCompute(List<ResolvedComponentInfo> targets)182     public void doCompute(List<ResolvedComponentInfo> targets) {
183         final long recentSinceTime = mCurrentTime - RECENCY_TIME_PERIOD;
184 
185         float mostRecencyScore = 1.0f;
186         float mostTimeSpentScore = 1.0f;
187         float mostLaunchScore = 1.0f;
188         float mostChooserScore = 1.0f;
189 
190         for (ResolvedComponentInfo target : targets) {
191             if (target.getResolveInfoAt(0) == null) {
192                 continue;
193             }
194             final ResolverTarget resolverTarget = new ResolverTarget();
195             final UserHandle resolvedComponentUserSpace =
196                     target.getResolveInfoAt(0).userHandle;
197             final Map<ComponentName, ResolverTarget> targetsDict =
198                     mTargetsDictPerUser.get(resolvedComponentUserSpace);
199             final Map<String, UsageStats> stats = mStatsPerUser.get(resolvedComponentUserSpace);
200             if (targetsDict != null && stats != null) {
201                 targetsDict.put(target.name, resolverTarget);
202                 final UsageStats pkStats = stats.get(target.name.getPackageName());
203                 if (pkStats != null) {
204                     // Only count recency for apps that weren't the caller
205                     // since the caller is always the most recent.
206                     // Persistent processes muck this up, so omit them too.
207                     if (!target.name.getPackageName().equals(mReferrerPackage)
208                             && !isPersistentProcess(target)) {
209                         final float recencyScore =
210                                 (float) Math.max(pkStats.getLastTimeUsed() - recentSinceTime, 0);
211                         resolverTarget.setRecencyScore(recencyScore);
212                         if (recencyScore > mostRecencyScore) {
213                             mostRecencyScore = recencyScore;
214                         }
215                     }
216                     final float timeSpentScore = (float) pkStats.getTotalTimeInForeground();
217                     resolverTarget.setTimeSpentScore(timeSpentScore);
218                     if (timeSpentScore > mostTimeSpentScore) {
219                         mostTimeSpentScore = timeSpentScore;
220                     }
221                     final float launchScore = (float) pkStats.mLaunchCount;
222                     resolverTarget.setLaunchScore(launchScore);
223                     if (launchScore > mostLaunchScore) {
224                         mostLaunchScore = launchScore;
225                     }
226 
227                     float chooserScore = 0.0f;
228                     if (pkStats.mChooserCounts != null && mAction != null
229                             && pkStats.mChooserCounts.get(mAction) != null) {
230                         chooserScore = (float) pkStats.mChooserCounts.get(mAction)
231                                 .getOrDefault(mContentType, 0);
232                         if (mAnnotations != null) {
233                             final int size = mAnnotations.length;
234                             for (int i = 0; i < size; i++) {
235                                 chooserScore += (float) pkStats.mChooserCounts.get(mAction)
236                                         .getOrDefault(mAnnotations[i], 0);
237                             }
238                         }
239                     }
240                     if (DEBUG) {
241                         if (mAction == null) {
242                             Log.d(TAG, "Action type is null");
243                         } else {
244                             Log.d(TAG, "Chooser Count of " + mAction + ":"
245                                     + target.name.getPackageName() + " is "
246                                     + Float.toString(chooserScore));
247                         }
248                     }
249                     resolverTarget.setChooserScore(chooserScore);
250                     if (chooserScore > mostChooserScore) {
251                         mostChooserScore = chooserScore;
252                     }
253                 }
254             }
255         }
256 
257         if (DEBUG) {
258             Log.d(TAG, "compute - mostRecencyScore: " + mostRecencyScore
259                     + " mostTimeSpentScore: " + mostTimeSpentScore
260                     + " mostLaunchScore: " + mostLaunchScore
261                     + " mostChooserScore: " + mostChooserScore);
262         }
263 
264         mTargets = new ArrayList<>();
265         for (UserHandle u : mTargetsDictPerUser.keySet()) {
266             mTargets.addAll(mTargetsDictPerUser.get(u).values());
267         }
268         for (ResolverTarget target : mTargets) {
269             final float recency = target.getRecencyScore() / mostRecencyScore;
270             setFeatures(target, recency * recency * RECENCY_MULTIPLIER,
271                     target.getLaunchScore() / mostLaunchScore,
272                     target.getTimeSpentScore() / mostTimeSpentScore,
273                     target.getChooserScore() / mostChooserScore);
274             addDefaultSelectProbability(target);
275             if (DEBUG) {
276                 Log.d(TAG, "Scores: " + target);
277             }
278         }
279         predictSelectProbabilities(mTargets);
280 
281         mComparatorModel = buildUpdatedModel();
282     }
283 
284     @Override
compare(ResolveInfo lhs, ResolveInfo rhs)285     public int compare(ResolveInfo lhs, ResolveInfo rhs) {
286         return mComparatorModel.getComparator().compare(lhs, rhs);
287     }
288 
289     @Override
getScore(TargetInfo targetInfo)290     public float getScore(TargetInfo targetInfo) {
291         return mComparatorModel.getScore(targetInfo);
292     }
293 
294     // update ranking model when the connection to it is valid.
295     @Override
updateModel(TargetInfo targetInfo)296     public void updateModel(TargetInfo targetInfo) {
297         synchronized (mLock) {
298             mComparatorModel.notifyOnTargetSelected(targetInfo);
299         }
300     }
301 
302     // unbind the service and clear unhandled messges.
303     @Override
destroy()304     public void destroy() {
305         mHandler.removeMessages(RANKER_SERVICE_RESULT);
306         mHandler.removeMessages(RANKER_RESULT_TIMEOUT);
307         if (mConnection != null) {
308             mContext.unbindService(mConnection);
309             mConnection.destroy();
310         }
311         afterCompute();
312         if (DEBUG) {
313             Log.d(TAG, "Unbinded Resolver Ranker.");
314         }
315     }
316 
317     // connect to a ranking service.
initRanker(Context context)318     private void initRanker(Context context) {
319         synchronized (mLock) {
320             if (mConnection != null && mRanker != null) {
321                 if (DEBUG) {
322                     Log.d(TAG, "Ranker still exists; reusing the existing one.");
323                 }
324                 return;
325             }
326         }
327         Intent intent = resolveRankerService();
328         if (intent == null) {
329             return;
330         }
331         mConnectSignal = new CountDownLatch(1);
332         mConnection = new ResolverRankerServiceConnection(mConnectSignal);
333         context.bindServiceAsUser(intent, mConnection, Context.BIND_AUTO_CREATE, UserHandle.SYSTEM);
334     }
335 
336     // resolve the service for ranking.
resolveRankerService()337     private Intent resolveRankerService() {
338         Intent intent = new Intent(ResolverRankerService.SERVICE_INTERFACE);
339         final List<ResolveInfo> resolveInfos = mContext.getPackageManager()
340                 .queryIntentServices(intent, 0);
341         for (ResolveInfo resolveInfo : resolveInfos) {
342             if (resolveInfo == null || resolveInfo.serviceInfo == null
343                     || resolveInfo.serviceInfo.applicationInfo == null) {
344                 if (DEBUG) {
345                     Log.d(TAG, "Failed to retrieve a ranker: " + resolveInfo);
346                 }
347                 continue;
348             }
349             ComponentName componentName = new ComponentName(
350                     resolveInfo.serviceInfo.applicationInfo.packageName,
351                     resolveInfo.serviceInfo.name);
352             try {
353                 final String perm =
354                         mContext.getPackageManager().getServiceInfo(componentName, 0).permission;
355                 if (!ResolverRankerService.BIND_PERMISSION.equals(perm)) {
356                     Log.w(TAG, "ResolverRankerService " + componentName + " does not require"
357                             + " permission " + ResolverRankerService.BIND_PERMISSION
358                             + " - this service will not be queried for "
359                             + "ResolverRankerServiceResolverComparator. add android:permission=\""
360                             + ResolverRankerService.BIND_PERMISSION + "\""
361                             + " to the <service> tag for " + componentName
362                             + " in the manifest.");
363                     continue;
364                 }
365                 if (PackageManager.PERMISSION_GRANTED != mContext.getPackageManager()
366                         .checkPermission(ResolverRankerService.HOLD_PERMISSION,
367                                 resolveInfo.serviceInfo.packageName)) {
368                     Log.w(TAG, "ResolverRankerService " + componentName + " does not hold"
369                             + " permission " + ResolverRankerService.HOLD_PERMISSION
370                             + " - this service will not be queried for "
371                             + "ResolverRankerServiceResolverComparator.");
372                     continue;
373                 }
374             } catch (NameNotFoundException e) {
375                 Log.e(TAG, "Could not look up service " + componentName
376                         + "; component name not found");
377                 continue;
378             }
379             if (DEBUG) {
380                 Log.d(TAG, "Succeeded to retrieve a ranker: " + componentName);
381             }
382             mResolvedRankerName = componentName;
383             intent.setComponent(componentName);
384             return intent;
385         }
386         return null;
387     }
388 
389     private class ResolverRankerServiceConnection implements ServiceConnection {
390         private final CountDownLatch mConnectSignal;
391 
ResolverRankerServiceConnection(CountDownLatch connectSignal)392         ResolverRankerServiceConnection(CountDownLatch connectSignal) {
393             mConnectSignal = connectSignal;
394         }
395 
396         public final IResolverRankerResult resolverRankerResult =
397                 new ResolverRankerResultCallback(mLock, mHandler);
398 
399         @Override
onServiceConnected(ComponentName name, IBinder service)400         public void onServiceConnected(ComponentName name, IBinder service) {
401             if (DEBUG) {
402                 Log.d(TAG, "onServiceConnected: " + name);
403             }
404             synchronized (mLock) {
405                 mRanker = IResolverRankerService.Stub.asInterface(service);
406                 mComparatorModel = buildUpdatedModel();
407                 mConnectSignal.countDown();
408             }
409         }
410 
411         @Override
onServiceDisconnected(ComponentName name)412         public void onServiceDisconnected(ComponentName name) {
413             if (DEBUG) {
414                 Log.d(TAG, "onServiceDisconnected: " + name);
415             }
416             synchronized (mLock) {
417                 destroy();
418             }
419         }
420 
destroy()421         public void destroy() {
422             synchronized (mLock) {
423                 mRanker = null;
424                 mComparatorModel = buildUpdatedModel();
425             }
426         }
427     }
428 
429     private static class ResolverRankerResultCallback extends IResolverRankerResult.Stub {
430         private final Object mLock;
431         private final WeakReference<Handler> mHandlerRef;
432 
ResolverRankerResultCallback(Object lock, Handler handler)433         private ResolverRankerResultCallback(Object lock, Handler handler) {
434             mLock = lock;
435             mHandlerRef = new WeakReference<>(handler);
436         }
437 
438         @Override
sendResult(List<ResolverTarget> targets)439         public void sendResult(List<ResolverTarget> targets) throws RemoteException {
440             if (DEBUG) {
441                 Log.d(TAG, "Sending Result back to Resolver: " + targets);
442             }
443             synchronized (mLock) {
444                 final Message msg = Message.obtain();
445                 msg.what = RANKER_SERVICE_RESULT;
446                 msg.obj = targets;
447                 Handler handler = mHandlerRef.get();
448                 if (handler != null) {
449                     handler.sendMessage(msg);
450                 }
451             }
452         }
453     }
454 
455     @Override
beforeCompute()456     void beforeCompute() {
457         super.beforeCompute();
458         for (UserHandle userHandle : mTargetsDictPerUser.keySet()) {
459             mTargetsDictPerUser.get(userHandle).clear();
460         }
461         mTargets = null;
462         mRankerServiceName = new ComponentName(mContext, this.getClass());
463         mComparatorModel = buildUpdatedModel();
464         mResolvedRankerName = null;
465         initRanker(mContext);
466     }
467 
468     // predict select probabilities if ranking service is valid.
predictSelectProbabilities(List<ResolverTarget> targets)469     private void predictSelectProbabilities(List<ResolverTarget> targets) {
470         if (mConnection == null) {
471             if (DEBUG) {
472                 Log.d(TAG, "Has not found valid ResolverRankerService; Skip Prediction");
473             }
474         } else {
475             try {
476                 mConnectSignal.await(CONNECTION_COST_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
477                 synchronized (mLock) {
478                     if (mRanker != null) {
479                         mRanker.predict(targets, mConnection.resolverRankerResult);
480                         return;
481                     } else {
482                         if (DEBUG) {
483                             Log.d(TAG, "Ranker has not been initialized; skip predict.");
484                         }
485                     }
486                 }
487             } catch (InterruptedException e) {
488                 Log.e(TAG, "Error in Wait for Service Connection.");
489             } catch (RemoteException e) {
490                 Log.e(TAG, "Error in Predict: " + e);
491             }
492         }
493         afterCompute();
494     }
495 
496     // adds select prob as the default values, according to a pre-trained Logistic Regression model.
addDefaultSelectProbability(ResolverTarget target)497     private void addDefaultSelectProbability(ResolverTarget target) {
498         float sum = (2.5543f * target.getLaunchScore())
499                 + (2.8412f * target.getTimeSpentScore())
500                 + (0.269f * target.getRecencyScore())
501                 + (4.2222f * target.getChooserScore());
502         target.setSelectProbability((float) (1.0 / (1.0 + Math.exp(1.6568f - sum))));
503     }
504 
505     // sets features for each target
setFeatures(ResolverTarget target, float recencyScore, float launchScore, float timeSpentScore, float chooserScore)506     private void setFeatures(ResolverTarget target, float recencyScore, float launchScore,
507                              float timeSpentScore, float chooserScore) {
508         target.setRecencyScore(recencyScore);
509         target.setLaunchScore(launchScore);
510         target.setTimeSpentScore(timeSpentScore);
511         target.setChooserScore(chooserScore);
512     }
513 
isPersistentProcess(ResolvedComponentInfo rci)514     static boolean isPersistentProcess(ResolvedComponentInfo rci) {
515         if (rci != null && rci.getCount() > 0) {
516             int flags = rci.getResolveInfoAt(0).activityInfo.applicationInfo.flags;
517             return (flags & ApplicationInfo.FLAG_PERSISTENT) != 0;
518         }
519         return false;
520     }
521 
522     /**
523      * Re-construct a {@code ResolverRankerServiceComparatorModel} to replace the current model
524      * instance (if any) using the up-to-date {@code ResolverRankerServiceResolverComparator} ivar
525      * values.
526      *
527      * TODO: each time we replace the model instance, we're either updating the model to use
528      * adjusted data (which is appropriate), or we're providing a (late) value for one of our ivars
529      * that wasn't available the last time the model was updated. For those latter cases, we should
530      * just avoid creating the model altogether until we have all the prerequisites we'll need. Then
531      * we can probably simplify the logic in {@code ResolverRankerServiceComparatorModel} since we
532      * won't need to handle edge cases when the model data isn't fully prepared.
533      * (In some cases, these kinds of "updates" might interleave -- e.g., we might have finished
534      * initializing the first time and now want to adjust some data, but still need to wait for
535      * changes to propagate to the other ivars before rebuilding the model.)
536      */
buildUpdatedModel()537     private ResolverRankerServiceComparatorModel buildUpdatedModel() {
538         // TODO: we don't currently guarantee that the underlying target list/map won't be mutated,
539         // so the ResolverComparatorModel may provide inconsistent results. We should make immutable
540         // copies of the data (waiting for any necessary remaining data before creating the model).
541         return new ResolverRankerServiceComparatorModel(
542                 mStatsPerUser,
543                 mTargetsDictPerUser,
544                 mTargets,
545                 mCollator,
546                 mRanker,
547                 mRankerServiceName,
548                 (mAnnotations != null),
549                 mPmMap);
550     }
551 
552     /**
553      * Implementation of a {@code ResolverComparatorModel} that provides the same ranking logic as
554      * the legacy {@code ResolverRankerServiceResolverComparator}, as a refactoring step toward
555      * removing the complex legacy API.
556      */
557     static class ResolverRankerServiceComparatorModel implements ResolverComparatorModel {
558         private final Map<UserHandle, Map<String, UsageStats>> mStatsPerUser; // Treat as immutable.
559         // Treat as immutable.
560         private final Map<UserHandle, Map<ComponentName, ResolverTarget>> mTargetsDictPerUser;
561         private final List<ResolverTarget> mTargets;  // Treat as immutable.
562         private final Collator mCollator;
563         private final IResolverRankerService mRanker;
564         private final ComponentName mRankerServiceName;
565         private final boolean mAnnotationsUsed;
566         private final Map<UserHandle, PackageManager> mPmMap;
567 
568         // TODO: it doesn't look like we should have to pass both targets and targetsDict, but it's
569         // not written in a way that makes it clear whether we can derive one from the other (at
570         // least in this constructor).
ResolverRankerServiceComparatorModel( Map<UserHandle, Map<String, UsageStats>> statsPerUser, Map<UserHandle, Map<ComponentName, ResolverTarget>> targetsDictPerUser, List<ResolverTarget> targets, Collator collator, IResolverRankerService ranker, ComponentName rankerServiceName, boolean annotationsUsed, Map<UserHandle, PackageManager> pmMap)571         ResolverRankerServiceComparatorModel(
572                 Map<UserHandle, Map<String, UsageStats>> statsPerUser,
573                 Map<UserHandle, Map<ComponentName, ResolverTarget>> targetsDictPerUser,
574                 List<ResolverTarget> targets,
575                 Collator collator,
576                 IResolverRankerService ranker,
577                 ComponentName rankerServiceName,
578                 boolean annotationsUsed,
579                 Map<UserHandle, PackageManager> pmMap) {
580             mStatsPerUser = statsPerUser;
581             mTargetsDictPerUser = targetsDictPerUser;
582             mTargets = targets;
583             mCollator = collator;
584             mRanker = ranker;
585             mRankerServiceName = rankerServiceName;
586             mAnnotationsUsed = annotationsUsed;
587             mPmMap = pmMap;
588         }
589 
590         @Override
getComparator()591         public Comparator<ResolveInfo> getComparator() {
592             // TODO: doCompute() doesn't seem to be concerned about null-checking mStats. Is that
593             // a bug there, or do we have a way of knowing it will be non-null under certain
594             // conditions?
595             return (lhs, rhs) -> {
596                 final ResolverTarget lhsTarget =
597                         getActivityResolverTargetForUser(lhs.activityInfo, lhs.userHandle);
598                 final ResolverTarget rhsTarget =
599                         getActivityResolverTargetForUser(rhs.activityInfo, rhs.userHandle);
600 
601                 if (lhsTarget != null && rhsTarget != null) {
602                     final int selectProbabilityDiff = Float.compare(
603                             rhsTarget.getSelectProbability(), lhsTarget.getSelectProbability());
604 
605                     if (selectProbabilityDiff != 0) {
606                         return selectProbabilityDiff > 0 ? 1 : -1;
607                     }
608                 }
609 
610                 CharSequence sa = null;
611                 if (mPmMap.containsKey(lhs.userHandle)) {
612                     sa = lhs.loadLabel(mPmMap.get(lhs.userHandle));
613                 }
614                 if (sa == null) sa = lhs.activityInfo.name;
615                 CharSequence sb = null;
616                 if (mPmMap.containsKey(rhs.userHandle)) {
617                     sb = rhs.loadLabel(mPmMap.get(rhs.userHandle));
618                 }
619                 if (sb == null) sb = rhs.activityInfo.name;
620 
621                 return mCollator.compare(sa.toString().trim(), sb.toString().trim());
622             };
623         }
624 
625         @Override
getScore(TargetInfo targetInfo)626         public float getScore(TargetInfo targetInfo) {
627             ResolverTarget target = getResolverTargetForUserAndComponent(
628                     targetInfo.getResolvedComponentName(), targetInfo.getResolveInfo().userHandle);
629             if (target != null) {
630                 return target.getSelectProbability();
631             }
632             return 0;
633         }
634 
635         @Override
notifyOnTargetSelected(TargetInfo targetInfo)636         public void notifyOnTargetSelected(TargetInfo targetInfo) {
637             if (mRanker != null) {
638                 try {
639                     int selectedPos = -1;
640                     if (mTargetsDictPerUser.containsKey(targetInfo.getResolveInfo().userHandle)) {
641                         selectedPos = new ArrayList<>(mTargetsDictPerUser
642                                 .get(targetInfo.getResolveInfo().userHandle).keySet())
643                                 .indexOf(targetInfo.getResolvedComponentName());
644                     }
645                     if (selectedPos >= 0 && mTargets != null) {
646                         final float selectedProbability = getScore(targetInfo);
647                         int order = 0;
648                         for (ResolverTarget target : mTargets) {
649                             if (target.getSelectProbability() > selectedProbability) {
650                                 order++;
651                             }
652                         }
653                         logMetrics(order);
654                         mRanker.train(mTargets, selectedPos);
655                     } else {
656                         if (DEBUG) {
657                             Log.d(TAG, "Selected a unknown component: " + targetInfo
658                                     .getResolvedComponentName());
659                         }
660                     }
661                 } catch (RemoteException e) {
662                     Log.e(TAG, "Error in Train: " + e);
663                 }
664             } else {
665                 if (DEBUG) {
666                     Log.d(TAG, "Ranker is null; skip updateModel.");
667                 }
668             }
669         }
670 
671         /** Records metrics for evaluation. */
logMetrics(int selectedPos)672         private void logMetrics(int selectedPos) {
673             if (mRankerServiceName != null) {
674                 MetricsLogger metricsLogger = new MetricsLogger();
675                 LogMaker log = new LogMaker(MetricsEvent.ACTION_TARGET_SELECTED);
676                 log.setComponentName(mRankerServiceName);
677                 log.addTaggedData(MetricsEvent.FIELD_IS_CATEGORY_USED, mAnnotationsUsed ? 1 : 0);
678                 log.addTaggedData(MetricsEvent.FIELD_RANKED_POSITION, selectedPos);
679                 metricsLogger.write(log);
680             }
681         }
682 
683         @Nullable
getActivityResolverTargetForUser( ActivityInfo activity, UserHandle user)684         private ResolverTarget getActivityResolverTargetForUser(
685                 ActivityInfo activity, UserHandle user) {
686             return getResolverTargetForUserAndComponent(
687                     new ComponentName(activity.packageName, activity.name), user);
688         }
689 
690         @Nullable
getResolverTargetForUserAndComponent( ComponentName targetComponentName, UserHandle user)691         private ResolverTarget getResolverTargetForUserAndComponent(
692                 ComponentName targetComponentName, UserHandle user) {
693             if ((mStatsPerUser == null) || !mTargetsDictPerUser.containsKey(user)) {
694                 return null;
695             }
696             return mTargetsDictPerUser.get(user).get(targetComponentName);
697         }
698     }
699 }
700