1 /* 2 * Copyright (C) 2016 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.net; 18 19 import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES; 20 21 import android.annotation.NonNull; 22 import android.app.usage.NetworkStatsManager; 23 import android.content.Context; 24 import android.content.pm.PackageManager; 25 import android.net.DataUsageRequest; 26 import android.net.NetworkIdentitySet; 27 import android.net.NetworkStack; 28 import android.net.NetworkStats; 29 import android.net.NetworkStatsAccess; 30 import android.net.NetworkStatsCollection; 31 import android.net.NetworkStatsHistory; 32 import android.net.NetworkTemplate; 33 import android.net.netstats.IUsageCallback; 34 import android.os.Handler; 35 import android.os.HandlerThread; 36 import android.os.IBinder; 37 import android.os.Looper; 38 import android.os.Message; 39 import android.os.Process; 40 import android.os.RemoteException; 41 import android.util.ArrayMap; 42 import android.util.IndentingPrintWriter; 43 import android.util.Log; 44 import android.util.SparseArray; 45 46 import com.android.internal.annotations.VisibleForTesting; 47 import com.android.net.module.util.PerUidCounter; 48 49 import java.util.concurrent.atomic.AtomicInteger; 50 51 /** 52 * Manages observers of {@link NetworkStats}. Allows observers to be notified when 53 * data usage has been reported in {@link NetworkStatsService}. An observer can set 54 * a threshold of how much data it cares about to be notified. 55 */ 56 class NetworkStatsObservers { 57 private static final String TAG = "NetworkStatsObservers"; 58 private static final boolean LOG = true; 59 private static final boolean LOGV = false; 60 61 private static final int MSG_REGISTER = 1; 62 private static final int MSG_UNREGISTER = 2; 63 private static final int MSG_UPDATE_STATS = 3; 64 65 private static final int DUMP_USAGE_REQUESTS_COUNT = 200; 66 67 // The maximum number of request allowed per uid before an exception is thrown. 68 @VisibleForTesting 69 static final int MAX_REQUESTS_PER_UID = 100; 70 71 // All access to this map must be done from the handler thread. 72 // indexed by DataUsageRequest#requestId 73 private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>(); 74 75 // Request counters per uid, this is thread safe. 76 private final PerUidCounter mDataUsageRequestsPerUid = new PerUidCounter(MAX_REQUESTS_PER_UID); 77 78 // Sequence number of DataUsageRequests 79 private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger(); 80 81 // Lazily instantiated when an observer is registered. 82 private volatile Handler mHandler; 83 84 /** 85 * Creates a wrapper that contains the caller context and a normalized request. 86 * The request should be returned to the caller app, and the wrapper should be sent to this 87 * object through #addObserver by the service handler. 88 * 89 * <p>It will register the observer asynchronously, so it is safe to call from any thread. 90 * 91 * @return the normalized request wrapped within {@link RequestInfo}. 92 */ register(@onNull Context context, @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)93 public DataUsageRequest register(@NonNull Context context, 94 @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback, 95 int callingPid, int callingUid, @NonNull String callingPackage, 96 @NetworkStatsAccess.Level int accessLevel) { 97 DataUsageRequest request = buildRequest(context, inputRequest, callingUid); 98 RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid, callingUid, 99 callingPackage, accessLevel); 100 if (LOG) Log.d(TAG, "Registering observer for " + requestInfo); 101 mDataUsageRequestsPerUid.incrementCountOrThrow(callingUid); 102 103 getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo)); 104 return request; 105 } 106 107 /** 108 * Unregister a data usage observer. 109 * 110 * <p>It will unregister the observer asynchronously, so it is safe to call from any thread. 111 */ unregister(DataUsageRequest request, int callingUid)112 public void unregister(DataUsageRequest request, int callingUid) { 113 getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */, 114 request)); 115 } 116 117 /** 118 * Updates data usage statistics of registered observers and notifies if limits are reached. 119 * 120 * <p>It will update stats asynchronously, so it is safe to call from any thread. 121 */ updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)122 public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, 123 ArrayMap<String, NetworkIdentitySet> activeIfaces, 124 ArrayMap<String, NetworkIdentitySet> activeUidIfaces, 125 long currentTime) { 126 StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces, 127 activeUidIfaces, currentTime); 128 getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext)); 129 } 130 getHandler()131 private Handler getHandler() { 132 if (mHandler == null) { 133 synchronized (this) { 134 if (mHandler == null) { 135 if (LOGV) Log.v(TAG, "Creating handler"); 136 mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback); 137 } 138 } 139 } 140 return mHandler; 141 } 142 143 @VisibleForTesting getHandlerLooperLocked()144 protected Looper getHandlerLooperLocked() { 145 // TODO: Currently, callbacks are dispatched on this thread if the caller register 146 // callback without supplying a Handler. To ensure that the service handler thread 147 // is not blocked by client code, the observers must create their own thread. Once 148 // all callbacks are dispatched outside of the handler thread, the service handler 149 // thread can be used here. 150 HandlerThread handlerThread = new HandlerThread(TAG); 151 handlerThread.start(); 152 return handlerThread.getLooper(); 153 } 154 155 private Handler.Callback mHandlerCallback = new Handler.Callback() { 156 @Override 157 public boolean handleMessage(Message msg) { 158 switch (msg.what) { 159 case MSG_REGISTER: { 160 handleRegister((RequestInfo) msg.obj); 161 return true; 162 } 163 case MSG_UNREGISTER: { 164 handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */); 165 return true; 166 } 167 case MSG_UPDATE_STATS: { 168 handleUpdateStats((StatsContext) msg.obj); 169 return true; 170 } 171 default: { 172 return false; 173 } 174 } 175 } 176 }; 177 178 /** 179 * Adds a {@link RequestInfo} as an observer. 180 * Should only be called from the handler thread otherwise there will be a race condition 181 * on mDataUsageRequests. 182 */ handleRegister(RequestInfo requestInfo)183 private void handleRegister(RequestInfo requestInfo) { 184 mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo); 185 } 186 187 /** 188 * Removes a {@link DataUsageRequest} if the calling uid is authorized. 189 * Should only be called from the handler thread otherwise there will be a race condition 190 * on mDataUsageRequests. 191 */ handleUnregister(DataUsageRequest request, int callingUid)192 private void handleUnregister(DataUsageRequest request, int callingUid) { 193 RequestInfo requestInfo; 194 requestInfo = mDataUsageRequests.get(request.requestId); 195 if (requestInfo == null) { 196 if (LOG) Log.d(TAG, "Trying to unregister unknown request " + request); 197 return; 198 } 199 if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) { 200 Log.w(TAG, "Caller uid " + callingUid + " is not owner of " + request); 201 return; 202 } 203 204 if (LOG) Log.d(TAG, "Unregistering " + requestInfo); 205 mDataUsageRequests.remove(request.requestId); 206 mDataUsageRequestsPerUid.decrementCountOrThrow(requestInfo.mCallingUid); 207 requestInfo.unlinkDeathRecipient(); 208 requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED); 209 } 210 handleUpdateStats(StatsContext statsContext)211 private void handleUpdateStats(StatsContext statsContext) { 212 if (mDataUsageRequests.size() == 0) { 213 return; 214 } 215 216 for (int i = 0; i < mDataUsageRequests.size(); i++) { 217 RequestInfo requestInfo = mDataUsageRequests.valueAt(i); 218 requestInfo.updateStats(statsContext); 219 } 220 } 221 buildRequest(Context context, DataUsageRequest request, int callingUid)222 private DataUsageRequest buildRequest(Context context, DataUsageRequest request, 223 int callingUid) { 224 // For non-NETWORK_STACK permission uid, cap the minimum threshold to a safe default to 225 // avoid too many callbacks. 226 final long thresholdInBytes = (context.checkPermission( 227 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, Process.myPid(), callingUid) 228 == PackageManager.PERMISSION_GRANTED ? request.thresholdInBytes 229 : Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes)); 230 if (thresholdInBytes > request.thresholdInBytes) { 231 Log.w(TAG, "Threshold was too low for " + request 232 + ". Overriding to a safer default of " + thresholdInBytes + " bytes"); 233 } 234 return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(), 235 request.template, thresholdInBytes); 236 } 237 buildRequestInfo(DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)238 private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback, 239 int callingPid, int callingUid, @NonNull String callingPackage, 240 @NetworkStatsAccess.Level int accessLevel) { 241 if (accessLevel <= NetworkStatsAccess.Level.USER) { 242 return new UserUsageRequestInfo(this, request, callback, callingPid, 243 callingUid, callingPackage, accessLevel); 244 } else { 245 // Safety check in case a new access level is added and we forgot to update this 246 if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) { 247 throw new IllegalArgumentException( 248 "accessLevel " + accessLevel + " is less than DEVICESUMMARY."); 249 } 250 return new NetworkUsageRequestInfo(this, request, callback, callingPid, 251 callingUid, callingPackage, accessLevel); 252 } 253 } 254 255 /** 256 * Tracks information relevant to a data usage observer. 257 * It will notice when the calling process dies so we can self-expire. 258 */ 259 private abstract static class RequestInfo implements IBinder.DeathRecipient { 260 private final NetworkStatsObservers mStatsObserver; 261 protected final DataUsageRequest mRequest; 262 private final IUsageCallback mCallback; 263 protected final int mCallingPid; 264 protected final int mCallingUid; 265 protected final String mCallingPackage; 266 protected final @NetworkStatsAccess.Level int mAccessLevel; 267 protected NetworkStatsRecorder mRecorder; 268 protected NetworkStatsCollection mCollection; 269 RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)270 RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, 271 IUsageCallback callback, int callingPid, int callingUid, 272 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { 273 mStatsObserver = statsObserver; 274 mRequest = request; 275 mCallback = callback; 276 mCallingPid = callingPid; 277 mCallingUid = callingUid; 278 mCallingPackage = callingPackage; 279 mAccessLevel = accessLevel; 280 281 try { 282 mCallback.asBinder().linkToDeath(this, 0); 283 } catch (RemoteException e) { 284 binderDied(); 285 } 286 } 287 288 @Override binderDied()289 public void binderDied() { 290 if (LOGV) { 291 Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mCallback + ")"); 292 } 293 mStatsObserver.unregister(mRequest, Process.SYSTEM_UID); 294 callCallback(NetworkStatsManager.CALLBACK_RELEASED); 295 } 296 297 @Override toString()298 public String toString() { 299 return "RequestInfo from pid/uid:" + mCallingPid + "/" + mCallingUid 300 + "(" + mCallingPackage + ")" 301 + " for " + mRequest + " accessLevel:" + mAccessLevel; 302 } 303 unlinkDeathRecipient()304 private void unlinkDeathRecipient() { 305 mCallback.asBinder().unlinkToDeath(this, 0); 306 } 307 308 /** 309 * Update stats given the samples and interface to identity mappings. 310 */ updateStats(StatsContext statsContext)311 private void updateStats(StatsContext statsContext) { 312 if (mRecorder == null) { 313 // First run; establish baseline stats 314 resetRecorder(); 315 recordSample(statsContext); 316 return; 317 } 318 recordSample(statsContext); 319 320 if (checkStats()) { 321 resetRecorder(); 322 callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED); 323 } 324 } 325 callCallback(int callbackType)326 private void callCallback(int callbackType) { 327 try { 328 if (LOGV) { 329 Log.v(TAG, "sending notification " + callbackTypeToName(callbackType) 330 + " for " + mRequest); 331 } 332 switch (callbackType) { 333 case NetworkStatsManager.CALLBACK_LIMIT_REACHED: 334 mCallback.onThresholdReached(mRequest); 335 break; 336 case NetworkStatsManager.CALLBACK_RELEASED: 337 mCallback.onCallbackReleased(mRequest); 338 break; 339 } 340 } catch (RemoteException e) { 341 // May occur naturally in the race of binder death. 342 Log.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest); 343 } 344 } 345 resetRecorder()346 private void resetRecorder() { 347 mRecorder = new NetworkStatsRecorder(); 348 mCollection = mRecorder.getSinceBoot(); 349 } 350 checkStats()351 protected abstract boolean checkStats(); 352 recordSample(StatsContext statsContext)353 protected abstract void recordSample(StatsContext statsContext); 354 callbackTypeToName(int callbackType)355 private String callbackTypeToName(int callbackType) { 356 switch (callbackType) { 357 case NetworkStatsManager.CALLBACK_LIMIT_REACHED: 358 return "LIMIT_REACHED"; 359 case NetworkStatsManager.CALLBACK_RELEASED: 360 return "RELEASED"; 361 default: 362 return "UNKNOWN"; 363 } 364 } 365 } 366 367 private static class NetworkUsageRequestInfo extends RequestInfo { NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)368 NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, 369 IUsageCallback callback, int callingPid, int callingUid, 370 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { 371 super(statsObserver, request, callback, callingPid, callingUid, callingPackage, 372 accessLevel); 373 } 374 375 @Override checkStats()376 protected boolean checkStats() { 377 long bytesSoFar = getTotalBytesForNetwork(mRequest.template); 378 if (LOGV) { 379 Log.v(TAG, bytesSoFar + " bytes so far since notification for " 380 + mRequest.template); 381 } 382 if (bytesSoFar > mRequest.thresholdInBytes) { 383 return true; 384 } 385 return false; 386 } 387 388 @Override recordSample(StatsContext statsContext)389 protected void recordSample(StatsContext statsContext) { 390 // Recorder does not need to be locked in this context since only the handler 391 // thread will update it. We pass a null VPN array because usage is aggregated by uid 392 // for this snapshot, so VPN traffic can't be reattributed to responsible apps. 393 mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces, 394 statsContext.mCurrentTime); 395 } 396 397 /** 398 * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate 399 * over all buckets, which in this case should be only one since we built it big enough 400 * that it will outlive the caller. If it doesn't, then there will be multiple buckets. 401 */ getTotalBytesForNetwork(NetworkTemplate template)402 private long getTotalBytesForNetwork(NetworkTemplate template) { 403 NetworkStats stats = mCollection.getSummary(template, 404 Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, 405 mAccessLevel, mCallingUid); 406 return stats.getTotalBytes(); 407 } 408 } 409 410 private static class UserUsageRequestInfo extends RequestInfo { UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel)411 UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, 412 IUsageCallback callback, int callingPid, int callingUid, 413 @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { 414 super(statsObserver, request, callback, callingPid, callingUid, 415 callingPackage, accessLevel); 416 } 417 418 @Override checkStats()419 protected boolean checkStats() { 420 int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid); 421 422 for (int i = 0; i < uidsToMonitor.length; i++) { 423 long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]); 424 if (bytesSoFar > mRequest.thresholdInBytes) { 425 return true; 426 } 427 } 428 return false; 429 } 430 431 @Override recordSample(StatsContext statsContext)432 protected void recordSample(StatsContext statsContext) { 433 // Recorder does not need to be locked in this context since only the handler 434 // thread will update it. We pass the VPN info so VPN traffic is reattributed to 435 // responsible apps. 436 mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces, 437 statsContext.mCurrentTime); 438 } 439 440 /** 441 * Reads all stats matching the given template and uid. Ther history will likely only 442 * contain one bucket per ident since we build it big enough that it will outlive the 443 * caller lifetime. 444 */ getTotalBytesForNetworkUid(NetworkTemplate template, int uid)445 private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) { 446 try { 447 NetworkStatsHistory history = mCollection.getHistory(template, null, uid, 448 NetworkStats.SET_ALL, NetworkStats.TAG_NONE, 449 NetworkStatsHistory.FIELD_ALL, 450 Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, 451 mAccessLevel, mCallingUid); 452 return history.getTotalBytes(); 453 } catch (SecurityException e) { 454 if (LOGV) { 455 Log.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid " 456 + uid); 457 } 458 return 0; 459 } 460 } 461 } 462 463 private static class StatsContext { 464 NetworkStats mXtSnapshot; 465 NetworkStats mUidSnapshot; 466 ArrayMap<String, NetworkIdentitySet> mActiveIfaces; 467 ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces; 468 long mCurrentTime; 469 StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)470 StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, 471 ArrayMap<String, NetworkIdentitySet> activeIfaces, 472 ArrayMap<String, NetworkIdentitySet> activeUidIfaces, 473 long currentTime) { 474 mXtSnapshot = xtSnapshot; 475 mUidSnapshot = uidSnapshot; 476 mActiveIfaces = activeIfaces; 477 mActiveUidIfaces = activeUidIfaces; 478 mCurrentTime = currentTime; 479 } 480 } 481 dump(IndentingPrintWriter pw)482 public void dump(IndentingPrintWriter pw) { 483 for (int i = 0; i < Math.min(mDataUsageRequests.size(), DUMP_USAGE_REQUESTS_COUNT); i++) { 484 pw.println(mDataUsageRequests.valueAt(i)); 485 } 486 } 487 } 488