1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.net;
18 
19 import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;
20 
21 import static com.android.internal.util.Preconditions.checkArgument;
22 
23 import android.app.usage.NetworkStatsManager;
24 import android.net.DataUsageRequest;
25 import android.net.NetworkStats;
26 import android.net.NetworkStatsHistory;
27 import android.net.NetworkTemplate;
28 import android.os.Bundle;
29 import android.os.Handler;
30 import android.os.HandlerThread;
31 import android.os.IBinder;
32 import android.os.Looper;
33 import android.os.Message;
34 import android.os.Messenger;
35 import android.os.Process;
36 import android.os.RemoteException;
37 import android.util.ArrayMap;
38 import android.util.Slog;
39 import android.util.SparseArray;
40 
41 import com.android.internal.annotations.VisibleForTesting;
42 import com.android.internal.net.VpnInfo;
43 
44 import java.util.concurrent.atomic.AtomicInteger;
45 
46 /**
47  * Manages observers of {@link NetworkStats}. Allows observers to be notified when
48  * data usage has been reported in {@link NetworkStatsService}. An observer can set
49  * a threshold of how much data it cares about to be notified.
50  */
51 class NetworkStatsObservers {
52     private static final String TAG = "NetworkStatsObservers";
53     private static final boolean LOGV = false;
54 
55     private static final int MSG_REGISTER = 1;
56     private static final int MSG_UNREGISTER = 2;
57     private static final int MSG_UPDATE_STATS = 3;
58 
59     // All access to this map must be done from the handler thread.
60     // indexed by DataUsageRequest#requestId
61     private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
62 
63     // Sequence number of DataUsageRequests
64     private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
65 
66     // Lazily instantiated when an observer is registered.
67     private volatile Handler mHandler;
68 
69     /**
70      * Creates a wrapper that contains the caller context and a normalized request.
71      * The request should be returned to the caller app, and the wrapper should be sent to this
72      * object through #addObserver by the service handler.
73      *
74      * <p>It will register the observer asynchronously, so it is safe to call from any thread.
75      *
76      * @return the normalized request wrapped within {@link RequestInfo}.
77      */
register(DataUsageRequest inputRequest, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)78     public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
79                 IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
80         DataUsageRequest request = buildRequest(inputRequest);
81         RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
82                 accessLevel);
83 
84         if (LOGV) Slog.v(TAG, "Registering observer for " + request);
85         getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
86         return request;
87     }
88 
89     /**
90      * Unregister a data usage observer.
91      *
92      * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
93      */
unregister(DataUsageRequest request, int callingUid)94     public void unregister(DataUsageRequest request, int callingUid) {
95         getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
96                 request));
97     }
98 
99     /**
100      * Updates data usage statistics of registered observers and notifies if limits are reached.
101      *
102      * <p>It will update stats asynchronously, so it is safe to call from any thread.
103      */
updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, VpnInfo[] vpnArray, long currentTime)104     public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
105                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
106                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
107                 VpnInfo[] vpnArray, long currentTime) {
108         StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
109                 activeUidIfaces, vpnArray, currentTime);
110         getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
111     }
112 
getHandler()113     private Handler getHandler() {
114         if (mHandler == null) {
115             synchronized (this) {
116                 if (mHandler == null) {
117                     if (LOGV) Slog.v(TAG, "Creating handler");
118                     mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
119                 }
120             }
121         }
122         return mHandler;
123     }
124 
125     @VisibleForTesting
getHandlerLooperLocked()126     protected Looper getHandlerLooperLocked() {
127         HandlerThread handlerThread = new HandlerThread(TAG);
128         handlerThread.start();
129         return handlerThread.getLooper();
130     }
131 
132     private Handler.Callback mHandlerCallback = new Handler.Callback() {
133         @Override
134         public boolean handleMessage(Message msg) {
135             switch (msg.what) {
136                 case MSG_REGISTER: {
137                     handleRegister((RequestInfo) msg.obj);
138                     return true;
139                 }
140                 case MSG_UNREGISTER: {
141                     handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
142                     return true;
143                 }
144                 case MSG_UPDATE_STATS: {
145                     handleUpdateStats((StatsContext) msg.obj);
146                     return true;
147                 }
148                 default: {
149                     return false;
150                 }
151             }
152         }
153     };
154 
155     /**
156      * Adds a {@link RequestInfo} as an observer.
157      * Should only be called from the handler thread otherwise there will be a race condition
158      * on mDataUsageRequests.
159      */
handleRegister(RequestInfo requestInfo)160     private void handleRegister(RequestInfo requestInfo) {
161         mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
162     }
163 
164     /**
165      * Removes a {@link DataUsageRequest} if the calling uid is authorized.
166      * Should only be called from the handler thread otherwise there will be a race condition
167      * on mDataUsageRequests.
168      */
handleUnregister(DataUsageRequest request, int callingUid)169     private void handleUnregister(DataUsageRequest request, int callingUid) {
170         RequestInfo requestInfo;
171         requestInfo = mDataUsageRequests.get(request.requestId);
172         if (requestInfo == null) {
173             if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
174             return;
175         }
176         if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
177             Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
178             return;
179         }
180 
181         if (LOGV) Slog.v(TAG, "Unregistering " + request);
182         mDataUsageRequests.remove(request.requestId);
183         requestInfo.unlinkDeathRecipient();
184         requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
185     }
186 
handleUpdateStats(StatsContext statsContext)187     private void handleUpdateStats(StatsContext statsContext) {
188         if (mDataUsageRequests.size() == 0) {
189             return;
190         }
191 
192         for (int i = 0; i < mDataUsageRequests.size(); i++) {
193             RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
194             requestInfo.updateStats(statsContext);
195         }
196     }
197 
buildRequest(DataUsageRequest request)198     private DataUsageRequest buildRequest(DataUsageRequest request) {
199         // Cap the minimum threshold to a safe default to avoid too many callbacks
200         long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
201         if (thresholdInBytes < request.thresholdInBytes) {
202             Slog.w(TAG, "Threshold was too low for " + request
203                     + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
204         }
205         return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
206                 request.template, thresholdInBytes);
207     }
208 
buildRequestInfo(DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)209     private RequestInfo buildRequestInfo(DataUsageRequest request,
210                 Messenger messenger, IBinder binder, int callingUid,
211                 @NetworkStatsAccess.Level int accessLevel) {
212         if (accessLevel <= NetworkStatsAccess.Level.USER) {
213             return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
214                     accessLevel);
215         } else {
216             // Safety check in case a new access level is added and we forgot to update this
217             checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
218             return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
219                     accessLevel);
220         }
221     }
222 
223     /**
224      * Tracks information relevant to a data usage observer.
225      * It will notice when the calling process dies so we can self-expire.
226      */
227     private abstract static class RequestInfo implements IBinder.DeathRecipient {
228         private final NetworkStatsObservers mStatsObserver;
229         protected final DataUsageRequest mRequest;
230         private final Messenger mMessenger;
231         private final IBinder mBinder;
232         protected final int mCallingUid;
233         protected final @NetworkStatsAccess.Level int mAccessLevel;
234         protected NetworkStatsRecorder mRecorder;
235         protected NetworkStatsCollection mCollection;
236 
RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)237         RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
238                     Messenger messenger, IBinder binder, int callingUid,
239                     @NetworkStatsAccess.Level int accessLevel) {
240             mStatsObserver = statsObserver;
241             mRequest = request;
242             mMessenger = messenger;
243             mBinder = binder;
244             mCallingUid = callingUid;
245             mAccessLevel = accessLevel;
246 
247             try {
248                 mBinder.linkToDeath(this, 0);
249             } catch (RemoteException e) {
250                 binderDied();
251             }
252         }
253 
254         @Override
binderDied()255         public void binderDied() {
256             if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
257                     + mRequest + ", " + mBinder + ")");
258             mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
259             callCallback(NetworkStatsManager.CALLBACK_RELEASED);
260         }
261 
262         @Override
toString()263         public String toString() {
264             return "RequestInfo from uid:" + mCallingUid
265                     + " for " + mRequest + " accessLevel:" + mAccessLevel;
266         }
267 
unlinkDeathRecipient()268         private void unlinkDeathRecipient() {
269             if (mBinder != null) {
270                 mBinder.unlinkToDeath(this, 0);
271             }
272         }
273 
274         /**
275          * Update stats given the samples and interface to identity mappings.
276          */
updateStats(StatsContext statsContext)277         private void updateStats(StatsContext statsContext) {
278             if (mRecorder == null) {
279                 // First run; establish baseline stats
280                 resetRecorder();
281                 recordSample(statsContext);
282                 return;
283             }
284             recordSample(statsContext);
285 
286             if (checkStats()) {
287                 resetRecorder();
288                 callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
289             }
290         }
291 
callCallback(int callbackType)292         private void callCallback(int callbackType) {
293             Bundle bundle = new Bundle();
294             bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
295             Message msg = Message.obtain();
296             msg.what = callbackType;
297             msg.setData(bundle);
298             try {
299                 if (LOGV) {
300                     Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
301                             + " for " + mRequest);
302                 }
303                 mMessenger.send(msg);
304             } catch (RemoteException e) {
305                 // May occur naturally in the race of binder death.
306                 Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
307             }
308         }
309 
resetRecorder()310         private void resetRecorder() {
311             mRecorder = new NetworkStatsRecorder();
312             mCollection = mRecorder.getSinceBoot();
313         }
314 
checkStats()315         protected abstract boolean checkStats();
316 
recordSample(StatsContext statsContext)317         protected abstract void recordSample(StatsContext statsContext);
318 
callbackTypeToName(int callbackType)319         private String callbackTypeToName(int callbackType) {
320             switch (callbackType) {
321                 case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
322                     return "LIMIT_REACHED";
323                 case NetworkStatsManager.CALLBACK_RELEASED:
324                     return "RELEASED";
325                 default:
326                     return "UNKNOWN";
327             }
328         }
329     }
330 
331     private static class NetworkUsageRequestInfo extends RequestInfo {
NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)332         NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
333                     Messenger messenger, IBinder binder, int callingUid,
334                     @NetworkStatsAccess.Level int accessLevel) {
335             super(statsObserver, request, messenger, binder, callingUid, accessLevel);
336         }
337 
338         @Override
checkStats()339         protected boolean checkStats() {
340             long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
341             if (LOGV) {
342                 Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
343                         + mRequest.template);
344             }
345             if (bytesSoFar > mRequest.thresholdInBytes) {
346                 return true;
347             }
348             return false;
349         }
350 
351         @Override
recordSample(StatsContext statsContext)352         protected void recordSample(StatsContext statsContext) {
353             // Recorder does not need to be locked in this context since only the handler
354             // thread will update it. We pass a null VPN array because usage is aggregated by uid
355             // for this snapshot, so VPN traffic can't be reattributed to responsible apps.
356             mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
357                     null /* vpnArray */, statsContext.mCurrentTime);
358         }
359 
360         /**
361          * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
362          * over all buckets, which in this case should be only one since we built it big enough
363          * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
364          */
getTotalBytesForNetwork(NetworkTemplate template)365         private long getTotalBytesForNetwork(NetworkTemplate template) {
366             NetworkStats stats = mCollection.getSummary(template,
367                     Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
368                     mAccessLevel, mCallingUid);
369             return stats.getTotalBytes();
370         }
371     }
372 
373     private static class UserUsageRequestInfo extends RequestInfo {
UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)374         UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
375                     Messenger messenger, IBinder binder, int callingUid,
376                     @NetworkStatsAccess.Level int accessLevel) {
377             super(statsObserver, request, messenger, binder, callingUid, accessLevel);
378         }
379 
380         @Override
checkStats()381         protected boolean checkStats() {
382             int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
383 
384             for (int i = 0; i < uidsToMonitor.length; i++) {
385                 long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
386                 if (bytesSoFar > mRequest.thresholdInBytes) {
387                     return true;
388                 }
389             }
390             return false;
391         }
392 
393         @Override
recordSample(StatsContext statsContext)394         protected void recordSample(StatsContext statsContext) {
395             // Recorder does not need to be locked in this context since only the handler
396             // thread will update it. We pass the VPN info so VPN traffic is reattributed to
397             // responsible apps.
398             mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
399                     statsContext.mVpnArray, statsContext.mCurrentTime);
400         }
401 
402         /**
403          * Reads all stats matching the given template and uid. Ther history will likely only
404          * contain one bucket per ident since we build it big enough that it will outlive the
405          * caller lifetime.
406          */
getTotalBytesForNetworkUid(NetworkTemplate template, int uid)407         private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
408             try {
409                 NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
410                         NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
411                         NetworkStatsHistory.FIELD_ALL,
412                         Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
413                         mAccessLevel, mCallingUid);
414                 return history.getTotalBytes();
415             } catch (SecurityException e) {
416                 if (LOGV) {
417                     Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
418                             + uid);
419                 }
420                 return 0;
421             }
422         }
423     }
424 
425     private static class StatsContext {
426         NetworkStats mXtSnapshot;
427         NetworkStats mUidSnapshot;
428         ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
429         ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
430         VpnInfo[] mVpnArray;
431         long mCurrentTime;
432 
StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, VpnInfo[] vpnArray, long currentTime)433         StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
434                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
435                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
436                 VpnInfo[] vpnArray, long currentTime) {
437             mXtSnapshot = xtSnapshot;
438             mUidSnapshot = uidSnapshot;
439             mActiveIfaces = activeIfaces;
440             mActiveUidIfaces = activeUidIfaces;
441             mVpnArray = vpnArray;
442             mCurrentTime = currentTime;
443         }
444     }
445 }
446