1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.net;
18 
import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;

import android.annotation.NonNull;
import android.app.usage.NetworkStatsManager;
import android.content.Context;
import android.content.pm.PackageManager;
import android.net.DataUsageRequest;
import android.net.NetworkIdentitySet;
import android.net.NetworkStack;
import android.net.NetworkStats;
import android.net.NetworkStatsAccess;
import android.net.NetworkStatsCollection;
import android.net.NetworkStatsHistory;
import android.net.NetworkTemplate;
import android.net.netstats.IUsageCallback;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.IBinder;
import android.os.Looper;
import android.os.Message;
import android.os.Process;
import android.os.RemoteException;
import android.util.ArrayMap;
import android.util.IndentingPrintWriter;
import android.util.Log;
import android.util.SparseArray;

import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.PerUidCounter;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
50 
51 /**
52  * Manages observers of {@link NetworkStats}. Allows observers to be notified when
53  * data usage has been reported in {@link NetworkStatsService}. An observer can set
54  * a threshold of how much data it cares about to be notified.
55  */
56 class NetworkStatsObservers {
57     private static final String TAG = "NetworkStatsObservers";
58     private static final boolean LOG = true;
59     private static final boolean LOGV = false;
60 
61     private static final int MSG_REGISTER = 1;
62     private static final int MSG_UNREGISTER = 2;
63     private static final int MSG_UPDATE_STATS = 3;
64 
65     private static final int DUMP_USAGE_REQUESTS_COUNT = 200;
66 
67     // The maximum number of request allowed per uid before an exception is thrown.
68     @VisibleForTesting
69     static final int MAX_REQUESTS_PER_UID = 100;
70 
71     // All access to this map must be done from the handler thread.
72     // indexed by DataUsageRequest#requestId
73     private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
74 
75     // Request counters per uid, this is thread safe.
76     private final PerUidCounter mDataUsageRequestsPerUid = new PerUidCounter(MAX_REQUESTS_PER_UID);
77 
78     // Sequence number of DataUsageRequests
79     private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
80 
81     // Lazily instantiated when an observer is registered.
82     private volatile Handler mHandler;
83 
84     /**
85      * Creates a wrapper that contains the caller context and a normalized request.
86      * The request should be returned to the caller app, and the wrapper should be sent to this
87      * object through #addObserver by the service handler.
88      *
89      * <p>It will register the observer asynchronously, so it is safe to call from any thread.
90      *
91      * @return the normalized request wrapped within {@link RequestInfo}.
92      */
register(@onNull Context context, @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)93     public DataUsageRequest register(@NonNull Context context,
94             @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback,
95             int callingPid, int callingUid, @NonNull String callingPackage,
96             @NetworkStatsAccess.Level int accessLevel) {
97         DataUsageRequest request = buildRequest(context, inputRequest, callingUid);
98         RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid, callingUid,
99                 callingPackage, accessLevel);
100         if (LOG) Log.d(TAG, "Registering observer for " + requestInfo);
101         mDataUsageRequestsPerUid.incrementCountOrThrow(callingUid);
102 
103         getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
104         return request;
105     }
106 
107     /**
108      * Unregister a data usage observer.
109      *
110      * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
111      */
unregister(DataUsageRequest request, int callingUid)112     public void unregister(DataUsageRequest request, int callingUid) {
113         getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
114                 request));
115     }
116 
117     /**
118      * Updates data usage statistics of registered observers and notifies if limits are reached.
119      *
120      * <p>It will update stats asynchronously, so it is safe to call from any thread.
121      */
updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)122     public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
123                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
124                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
125                 long currentTime) {
126         StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
127                 activeUidIfaces, currentTime);
128         getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
129     }
130 
getHandler()131     private Handler getHandler() {
132         if (mHandler == null) {
133             synchronized (this) {
134                 if (mHandler == null) {
135                     if (LOGV) Log.v(TAG, "Creating handler");
136                     mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
137                 }
138             }
139         }
140         return mHandler;
141     }
142 
143     @VisibleForTesting
getHandlerLooperLocked()144     protected Looper getHandlerLooperLocked() {
145         // TODO: Currently, callbacks are dispatched on this thread if the caller register
146         //  callback without supplying a Handler. To ensure that the service handler thread
147         //  is not blocked by client code, the observers must create their own thread. Once
148         //  all callbacks are dispatched outside of the handler thread, the service handler
149         //  thread can be used here.
150         HandlerThread handlerThread = new HandlerThread(TAG);
151         handlerThread.start();
152         return handlerThread.getLooper();
153     }
154 
155     private Handler.Callback mHandlerCallback = new Handler.Callback() {
156         @Override
157         public boolean handleMessage(Message msg) {
158             switch (msg.what) {
159                 case MSG_REGISTER: {
160                     handleRegister((RequestInfo) msg.obj);
161                     return true;
162                 }
163                 case MSG_UNREGISTER: {
164                     handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
165                     return true;
166                 }
167                 case MSG_UPDATE_STATS: {
168                     handleUpdateStats((StatsContext) msg.obj);
169                     return true;
170                 }
171                 default: {
172                     return false;
173                 }
174             }
175         }
176     };
177 
178     /**
179      * Adds a {@link RequestInfo} as an observer.
180      * Should only be called from the handler thread otherwise there will be a race condition
181      * on mDataUsageRequests.
182      */
handleRegister(RequestInfo requestInfo)183     private void handleRegister(RequestInfo requestInfo) {
184         mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
185     }
186 
187     /**
188      * Removes a {@link DataUsageRequest} if the calling uid is authorized.
189      * Should only be called from the handler thread otherwise there will be a race condition
190      * on mDataUsageRequests.
191      */
handleUnregister(DataUsageRequest request, int callingUid)192     private void handleUnregister(DataUsageRequest request, int callingUid) {
193         RequestInfo requestInfo;
194         requestInfo = mDataUsageRequests.get(request.requestId);
195         if (requestInfo == null) {
196             if (LOG) Log.d(TAG, "Trying to unregister unknown request " + request);
197             return;
198         }
199         if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
200             Log.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
201             return;
202         }
203 
204         if (LOG) Log.d(TAG, "Unregistering " + requestInfo);
205         mDataUsageRequests.remove(request.requestId);
206         mDataUsageRequestsPerUid.decrementCountOrThrow(requestInfo.mCallingUid);
207         requestInfo.unlinkDeathRecipient();
208         requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
209     }
210 
handleUpdateStats(StatsContext statsContext)211     private void handleUpdateStats(StatsContext statsContext) {
212         if (mDataUsageRequests.size() == 0) {
213             return;
214         }
215 
216         for (int i = 0; i < mDataUsageRequests.size(); i++) {
217             RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
218             requestInfo.updateStats(statsContext);
219         }
220     }
221 
buildRequest(Context context, DataUsageRequest request, int callingUid)222     private DataUsageRequest buildRequest(Context context, DataUsageRequest request,
223                 int callingUid) {
224         // For non-NETWORK_STACK permission uid, cap the minimum threshold to a safe default to
225         // avoid too many callbacks.
226         final long thresholdInBytes = (context.checkPermission(
227                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, Process.myPid(), callingUid)
228                 == PackageManager.PERMISSION_GRANTED ? request.thresholdInBytes
229                 : Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes));
230         if (thresholdInBytes > request.thresholdInBytes) {
231             Log.w(TAG, "Threshold was too low for " + request
232                     + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
233         }
234         return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
235                 request.template, thresholdInBytes);
236     }
237 
buildRequestInfo(DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)238     private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback,
239             int callingPid, int callingUid, @NonNull String callingPackage,
240             @NetworkStatsAccess.Level int accessLevel) {
241         if (accessLevel <= NetworkStatsAccess.Level.USER) {
242             return new UserUsageRequestInfo(this, request, callback, callingPid,
243                     callingUid, callingPackage, accessLevel);
244         } else {
245             // Safety check in case a new access level is added and we forgot to update this
246             if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) {
247                 throw new IllegalArgumentException(
248                         "accessLevel " + accessLevel + " is less than DEVICESUMMARY.");
249             }
250             return new NetworkUsageRequestInfo(this, request, callback, callingPid,
251                     callingUid, callingPackage, accessLevel);
252         }
253     }
254 
255     /**
256      * Tracks information relevant to a data usage observer.
257      * It will notice when the calling process dies so we can self-expire.
258      */
259     private abstract static class RequestInfo implements IBinder.DeathRecipient {
260         private final NetworkStatsObservers mStatsObserver;
261         protected final DataUsageRequest mRequest;
262         private final IUsageCallback mCallback;
263         protected final int mCallingPid;
264         protected final int mCallingUid;
265         protected final String mCallingPackage;
266         protected final @NetworkStatsAccess.Level int mAccessLevel;
267         protected NetworkStatsRecorder mRecorder;
268         protected NetworkStatsCollection mCollection;
269 
RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)270         RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
271                 IUsageCallback callback, int callingPid, int callingUid,
272                 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
273             mStatsObserver = statsObserver;
274             mRequest = request;
275             mCallback = callback;
276             mCallingPid = callingPid;
277             mCallingUid = callingUid;
278             mCallingPackage = callingPackage;
279             mAccessLevel = accessLevel;
280 
281             try {
282                 mCallback.asBinder().linkToDeath(this, 0);
283             } catch (RemoteException e) {
284                 binderDied();
285             }
286         }
287 
288         @Override
binderDied()289         public void binderDied() {
290             if (LOGV) {
291                 Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mCallback + ")");
292             }
293             mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
294             callCallback(NetworkStatsManager.CALLBACK_RELEASED);
295         }
296 
297         @Override
toString()298         public String toString() {
299             return "RequestInfo from pid/uid:" + mCallingPid + "/" + mCallingUid
300                     + "(" + mCallingPackage + ")"
301                     + " for " + mRequest + " accessLevel:" + mAccessLevel;
302         }
303 
unlinkDeathRecipient()304         private void unlinkDeathRecipient() {
305             mCallback.asBinder().unlinkToDeath(this, 0);
306         }
307 
308         /**
309          * Update stats given the samples and interface to identity mappings.
310          */
updateStats(StatsContext statsContext)311         private void updateStats(StatsContext statsContext) {
312             if (mRecorder == null) {
313                 // First run; establish baseline stats
314                 resetRecorder();
315                 recordSample(statsContext);
316                 return;
317             }
318             recordSample(statsContext);
319 
320             if (checkStats()) {
321                 resetRecorder();
322                 callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
323             }
324         }
325 
callCallback(int callbackType)326         private void callCallback(int callbackType) {
327             try {
328                 if (LOGV) {
329                     Log.v(TAG, "sending notification " + callbackTypeToName(callbackType)
330                             + " for " + mRequest);
331                 }
332                 switch (callbackType) {
333                     case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
334                         mCallback.onThresholdReached(mRequest);
335                         break;
336                     case NetworkStatsManager.CALLBACK_RELEASED:
337                         mCallback.onCallbackReleased(mRequest);
338                         break;
339                 }
340             } catch (RemoteException e) {
341                 // May occur naturally in the race of binder death.
342                 Log.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
343             }
344         }
345 
resetRecorder()346         private void resetRecorder() {
347             mRecorder = new NetworkStatsRecorder();
348             mCollection = mRecorder.getSinceBoot();
349         }
350 
checkStats()351         protected abstract boolean checkStats();
352 
recordSample(StatsContext statsContext)353         protected abstract void recordSample(StatsContext statsContext);
354 
callbackTypeToName(int callbackType)355         private String callbackTypeToName(int callbackType) {
356             switch (callbackType) {
357                 case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
358                     return "LIMIT_REACHED";
359                 case NetworkStatsManager.CALLBACK_RELEASED:
360                     return "RELEASED";
361                 default:
362                     return "UNKNOWN";
363             }
364         }
365     }
366 
367     private static class NetworkUsageRequestInfo extends RequestInfo {
NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)368         NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
369                 IUsageCallback callback, int callingPid, int callingUid,
370                 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
371             super(statsObserver, request, callback, callingPid, callingUid, callingPackage,
372                     accessLevel);
373         }
374 
375         @Override
checkStats()376         protected boolean checkStats() {
377             long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
378             if (LOGV) {
379                 Log.v(TAG, bytesSoFar + " bytes so far since notification for "
380                         + mRequest.template);
381             }
382             if (bytesSoFar > mRequest.thresholdInBytes) {
383                 return true;
384             }
385             return false;
386         }
387 
388         @Override
recordSample(StatsContext statsContext)389         protected void recordSample(StatsContext statsContext) {
390             // Recorder does not need to be locked in this context since only the handler
391             // thread will update it. We pass a null VPN array because usage is aggregated by uid
392             // for this snapshot, so VPN traffic can't be reattributed to responsible apps.
393             mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
394                     statsContext.mCurrentTime);
395         }
396 
397         /**
398          * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
399          * over all buckets, which in this case should be only one since we built it big enough
400          * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
401          */
getTotalBytesForNetwork(NetworkTemplate template)402         private long getTotalBytesForNetwork(NetworkTemplate template) {
403             NetworkStats stats = mCollection.getSummary(template,
404                     Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
405                     mAccessLevel, mCallingUid);
406             return stats.getTotalBytes();
407         }
408     }
409 
410     private static class UserUsageRequestInfo extends RequestInfo {
UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)411         UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
412                 IUsageCallback callback, int callingPid, int callingUid,
413                 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
414             super(statsObserver, request, callback, callingPid, callingUid,
415                     callingPackage, accessLevel);
416         }
417 
418         @Override
checkStats()419         protected boolean checkStats() {
420             int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
421 
422             for (int i = 0; i < uidsToMonitor.length; i++) {
423                 long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
424                 if (bytesSoFar > mRequest.thresholdInBytes) {
425                     return true;
426                 }
427             }
428             return false;
429         }
430 
431         @Override
recordSample(StatsContext statsContext)432         protected void recordSample(StatsContext statsContext) {
433             // Recorder does not need to be locked in this context since only the handler
434             // thread will update it. We pass the VPN info so VPN traffic is reattributed to
435             // responsible apps.
436             mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
437                     statsContext.mCurrentTime);
438         }
439 
440         /**
441          * Reads all stats matching the given template and uid. Ther history will likely only
442          * contain one bucket per ident since we build it big enough that it will outlive the
443          * caller lifetime.
444          */
getTotalBytesForNetworkUid(NetworkTemplate template, int uid)445         private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
446             try {
447                 NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
448                         NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
449                         NetworkStatsHistory.FIELD_ALL,
450                         Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
451                         mAccessLevel, mCallingUid);
452                 return history.getTotalBytes();
453             } catch (SecurityException e) {
454                 if (LOGV) {
455                     Log.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
456                             + uid);
457                 }
458                 return 0;
459             }
460         }
461     }
462 
463     private static class StatsContext {
464         NetworkStats mXtSnapshot;
465         NetworkStats mUidSnapshot;
466         ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
467         ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
468         long mCurrentTime;
469 
StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)470         StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
471                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
472                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
473                 long currentTime) {
474             mXtSnapshot = xtSnapshot;
475             mUidSnapshot = uidSnapshot;
476             mActiveIfaces = activeIfaces;
477             mActiveUidIfaces = activeUidIfaces;
478             mCurrentTime = currentTime;
479         }
480     }
481 
dump(IndentingPrintWriter pw)482     public void dump(IndentingPrintWriter pw) {
483         for (int i = 0; i < Math.min(mDataUsageRequests.size(), DUMP_USAGE_REQUESTS_COUNT); i++) {
484             pw.println(mDataUsageRequests.valueAt(i));
485         }
486     }
487 }
488