/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.healthconnect.storage.request;

import static android.health.connect.datatypes.AggregationType.AVG;
import static android.health.connect.datatypes.AggregationType.COUNT;
import static android.health.connect.datatypes.AggregationType.MAX;
import static android.health.connect.datatypes.AggregationType.MIN;
import static android.health.connect.datatypes.AggregationType.SUM;

import static com.android.server.healthconnect.storage.datatypehelpers.RecordHelper.APP_INFO_ID_COLUMN_NAME;

import android.annotation.NonNull;
import android.database.Cursor;
import android.health.connect.AggregateResult;
import android.health.connect.Constants;
import android.health.connect.LocalTimeRangeFilter;
import android.health.connect.TimeRangeFilter;
import android.health.connect.TimeRangeFilterHelper;
import android.health.connect.datatypes.AggregationType;
import android.util.ArrayMap;
import android.util.Pair;
import android.util.Slog;

import com.android.server.healthconnect.storage.TransactionManager;
import com.android.server.healthconnect.storage.datatypehelpers.AppInfoHelper;
import com.android.server.healthconnect.storage.datatypehelpers.RecordHelper;
import com.android.server.healthconnect.storage.datatypehelpers.aggregation.PriorityRecordsAggregator;
import com.android.server.healthconnect.storage.utils.OrderByClause;
import com.android.server.healthconnect.storage.utils.SqlJoin;
import com.android.server.healthconnect.storage.utils.StorageUtils;
import com.android.server.healthconnect.storage.utils.WhereClauses;

import java.time.Duration;
import java.time.LocalDateTime;
import java.time.Period;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A request for {@link TransactionManager} to query the DB for aggregation results
 *
 * @hide
 */
public class AggregateTableRequest {
    private static final String TAG = "HealthConnectAggregate";
    private static final String GROUP_BY_COLUMN_NAME = "category";

    private static final int MAX_NUMBER_OF_GROUPS = Constants.MAXIMUM_PAGE_SIZE;

    private final String mTableName;
    private final List<String> mColumnNamesToAggregate;
    private final AggregationType<?> mAggregationType;
    private final RecordHelper<?> mRecordHelper;
    /** Aggregation results keyed by zero-based group index. */
    private final Map<Integer, AggregateResult<?>> mAggregateResults = new ArrayMap<>();

    /**
     * Represents "start time" for interval record, and "time" for instant record.
     *
     * <p>{@link #mUseLocalTime} is already taken into account when this field is set, meaning if
     * {@link #mUseLocalTime} is {@code true}, then this field represent local time, otherwise
     * physical time.
     */
    private final String mTimeColumnName;

    private final WhereClauses mWhereClauses;
    private final SqlJoin mSqlJoin;
    private String mGroupByColumnName;
    private int mGroupBySize = 1;
    private final List<String> mAdditionalColumnsToFetch;
    private final AggregateParams.PriorityAggregationExtraParams mPriorityParams;
    private final boolean mUseLocalTime;
    /**
     * Boundaries of the aggregation groups: group {@code i} spans from element {@code i} to element
     * {@code i + 1}. Set by {@link #setTimeFilter} or {@link #setGroupBy}.
     */
    private List<Long> mTimeSplits;

    /**
     * Creates a request from the given aggregation parameters.
     *
     * @param params source of table name, columns to aggregate, time column names, join and
     *     priority parameters
     * @param aggregationType aggregation this request computes
     * @param recordHelper helper of the record type being aggregated
     * @param whereClauses filter appended to the generated SQL
     * @param useLocalTime whether the request works in local time instead of physical time
     */
    @SuppressWarnings("NullAway.Init") // TODO(b/317029272): fix this suppression
    public AggregateTableRequest(
            AggregateParams params,
            AggregationType<?> aggregationType,
            RecordHelper<?> recordHelper,
            WhereClauses whereClauses,
            boolean useLocalTime) {
        mTableName = params.getTableName();
        mColumnNamesToAggregate = params.getColumnsToFetch();
        mTimeColumnName = params.getTimeColumnName();
        mAggregationType = aggregationType;
        mRecordHelper = recordHelper;
        mSqlJoin = params.getJoin();
        mPriorityParams = params.getPriorityAggregationExtraParams();
        mWhereClauses = whereClauses;
        mAdditionalColumnsToFetch = new ArrayList<>();
        mAdditionalColumnsToFetch.add(params.getTimeOffsetColumnName());
        mAdditionalColumnsToFetch.add(mTimeColumnName);
        String endTimeColumnName = params.getExtraTimeColumnName();
        if (endTimeColumnName != null) {
            mAdditionalColumnsToFetch.add(endTimeColumnName);
        }
        mUseLocalTime = useLocalTime;
    }

    /**
     * @return {@link AggregationType} for this request
     */
    public AggregationType<?> getAggregationType() {
        return mAggregationType;
    }

    /**
     * @return {@link RecordHelper} for this request
     */
    public RecordHelper<?> getRecordHelper() {
        return mRecordHelper;
    }

    /**
     * @return results fetched after performing aggregate operation for this class, one entry per
     *     group in group order. An entry is {@code null} when no result was stored for that group.
     *     <p>Note: Only available after the call to {@link
     *     TransactionManager#populateWithAggregation} has been made
     */
    public List<AggregateResult<?>> getAggregateResults() {
        List<AggregateResult<?>> aggregateResults = new ArrayList<>(mGroupBySize);
        for (int i = 0; i < mGroupBySize; i++) {
            aggregateResults.add(mAggregateResults.get(i));
        }

        return aggregateResults;
    }

    /** Returns SQL statement to get data origins for the aggregation operation */
    public String getCommandToFetchAggregateMetadata() {
        final StringBuilder builder = new StringBuilder("SELECT DISTINCT ");
        builder.append(APP_INFO_ID_COLUMN_NAME).append(", ");
        return appendAggregateCommand(builder, /* isMetadata= */ true);
    }

    /** Returns name of the main time column (start time for Interval, time for Instant records) */
    public String getTimeColumnName() {
        return mTimeColumnName;
    }

    /** Returns whether request is using local time instead of physical one. */
    public boolean getUseLocalTime() {
        return mUseLocalTime;
    }

    /**
     * Returns SQL statement to perform aggregation operation.
     *
     * <p>For record types that support priority or are derived, the raw columns are selected
     * without an SQL aggregate function and without grouping; otherwise every column is wrapped in
     * the SQL function matching the aggregate operation type.
     */
    @NonNull
    public String getAggregationCommand() {
        final StringBuilder builder = new StringBuilder("SELECT ");
        boolean usingPriority =
                StorageUtils.supportsPriority(
                                mRecordHelper.getRecordIdentifier(),
                                mAggregationType.getAggregateOperationType())
                        || StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier());
        if (usingPriority) {
            for (String columnName : mColumnNamesToAggregate) {
                builder.append(columnName).append(", ");
            }
        } else {
            String aggCommand = getSqlCommandFor(mAggregationType.getAggregateOperationType());

            for (String columnName : mColumnNamesToAggregate) {
                builder.append(aggCommand)
                        .append("(")
                        .append(columnName)
                        .append(")")
                        .append(" as ")
                        .append(columnName)
                        .append(", ");
            }
        }

        // mAdditionalColumnsToFetch is final and always assigned in the constructor.
        for (String additionalColumnToFetch : mAdditionalColumnsToFetch) {
            builder.append(additionalColumnToFetch).append(", ");
        }

        // Priority/derived requests skip the SQL-level grouping, same as the metadata query.
        return appendAggregateCommand(builder, usingPriority);
    }

    /**
     * Sets time filter for table request.
     *
     * @param startTime start of the window; ignored together with {@code endTime} when negative
     * @param endTime end of the window; ignored together with {@code startTime} when smaller than
     *     {@code startTime}
     * @return this request
     */
    public AggregateTableRequest setTimeFilter(long startTime, long endTime) {
        // Return if the params will result in no impact on the query
        if (startTime < 0 || endTime < startTime) {
            return this;
        }

        mTimeSplits = List.of(startTime, endTime);
        return this;
    }

    /**
     * Sets group by fields.
     *
     * @param columnName column the groups are built on
     * @param period group length as a period; takes precedence over {@code duration} when not null
     * @param duration group length as a duration; used only when {@code period} is null
     * @param timeRangeFilter window that is split into groups
     * @throws IllegalArgumentException if both {@code period} and {@code duration} are null, if the
     *     duration is not positive, or if the number of groups exceeds the allowed maximum
     */
    public void setGroupBy(
            String columnName, Period period, Duration duration, TimeRangeFilter timeRangeFilter) {
        mGroupByColumnName = columnName;
        if (period != null) {
            mTimeSplits = getGroupSplitsForPeriod(timeRangeFilter, period);
        } else if (duration != null) {
            mTimeSplits = getGroupSplitsForDuration(timeRangeFilter, duration);
        } else {
            throw new IllegalArgumentException(
                    "Either aggregation period or duration should be not null");
        }
        mGroupBySize = mTimeSplits.size() - 1;

        if (Constants.DEBUG) {
            Slog.d(
                    TAG,
                    "Group aggregation splits: "
                            + mTimeSplits
                            + " number of groups: "
                            + mGroupBySize);
        }
    }

    /**
     * Builds the aggregate results from the fetched cursors.
     *
     * <p>The data cursor is processed by the derived, priority or plain path depending on the
     * record type and aggregate operation; afterwards the data origins read from {@code
     * metaDataCursor} are attached to every stored result.
     *
     * @param cursor result of the query returned by {@link #getAggregationCommand()}
     * @param metaDataCursor result of the query returned by {@link
     *     #getCommandToFetchAggregateMetadata()}
     */
    public void onResultsFetched(Cursor cursor, Cursor metaDataCursor) {
        if (StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier())) {
            deriveAggregate(cursor);
        } else if (StorageUtils.supportsPriority(
                mRecordHelper.getRecordIdentifier(),
                mAggregationType.getAggregateOperationType())) {
            processPriorityRequest(cursor);
        } else {
            processNoPrioritiesRequest(cursor);
        }

        updateResultWithDataOriginPackageNames(metaDataCursor);
    }

    /** Aggregates the raw rows with {@link PriorityRecordsAggregator} and stores per-group results. */
    private void processPriorityRequest(Cursor cursor) {
        List<Long> priorityList =
                StorageUtils.getAppIdPriorityList(mRecordHelper.getRecordIdentifier());
        PriorityRecordsAggregator aggregator =
                new PriorityRecordsAggregator(
                        mTimeSplits,
                        priorityList,
                        mAggregationType.getAggregationTypeIdentifier(),
                        mPriorityParams,
                        mUseLocalTime);
        aggregator.calculateAggregation(cursor);
        for (int groupNumber = 0; groupNumber < mGroupBySize; groupNumber++) {
            // Groups without a result are left absent from the result map.
            if (aggregator.getResultForGroup(groupNumber) == null) {
                continue;
            }

            AggregateResult<?> result;
            if (mAggregationType.getAggregateResultClass() == Long.class) {
                result =
                        new AggregateResult<>(
                                aggregator.getResultForGroup(groupNumber).longValue());
            } else {
                result = new AggregateResult<>(aggregator.getResultForGroup(groupNumber));
            }
            mAggregateResults.put(
                    groupNumber,
                    result.setZoneOffset(aggregator.getZoneOffsetForGroup(groupNumber)));
        }

        if (Constants.DEBUG) {
            Slog.d(TAG, "Priority aggregation result: " + mAggregateResults);
        }
    }

    /** Stores one result per cursor row, keyed by the row's group-by column value. */
    private void processNoPrioritiesRequest(Cursor cursor) {
        while (cursor.moveToNext()) {
            mAggregateResults.put(
                    StorageUtils.getCursorInt(cursor, GROUP_BY_COLUMN_NAME),
                    mRecordHelper.getAggregateResult(cursor, mAggregationType));
        }
    }

    /**
     * Maps an aggregate operation type to the name of the SQL aggregate function.
     *
     * @return the SQL function name, or {@code null} for an unrecognised operation type
     */
    @SuppressWarnings("NullAway") // TODO(b/317029272): fix this suppression
    private static String getSqlCommandFor(@AggregationType.AggregateOperationType int type) {
        return switch (type) {
            case MAX -> "MAX";
            case MIN -> "MIN";
            case AVG -> "AVG";
            case SUM -> "SUM";
            case COUNT -> "COUNT";
            default -> null;
        };
    }

    /**
     * Completes a SELECT statement whose column list (each column followed by ", ") is already in
     * {@code builder}.
     *
     * <p>Appends the group CASE column and GROUP BY clause when grouping applies, then the FROM
     * clause, optional join, WHERE clause and an ORDER BY on the time column.
     *
     * @param builder statement prefix ending with ", "
     * @param isMetadata when {@code true} no grouping is added to the statement; callers also pass
     *     {@code true} for priority/derived requests
     * @return the complete SQL statement
     */
    private String appendAggregateCommand(StringBuilder builder, boolean isMetadata) {
        boolean useGroupBy = mGroupByColumnName != null && !isMetadata;
        if (useGroupBy) {
            // Each row is labelled with the index of the [split(i), split(i + 1)) bucket it is in.
            builder.append(" CASE ");
            for (int i = 0; i < mTimeSplits.size() - 1; i++) {
                builder.append(" WHEN ")
                        .append(mTimeColumnName)
                        .append(" >= ")
                        .append(mTimeSplits.get(i))
                        .append(" AND ")
                        .append(mTimeColumnName)
                        .append(" < ")
                        .append(mTimeSplits.get(i + 1))
                        .append(" THEN ")
                        .append(i);
            }
            builder.append(" END " + GROUP_BY_COLUMN_NAME + " ");
        } else {
            builder.setLength(builder.length() - 2); // Remove the last 2 char i.e. ", "
        }

        builder.append(" FROM ").append(mTableName);
        if (mSqlJoin != null) {
            builder.append(mSqlJoin.getJoinCommand());
        }

        builder.append(mWhereClauses.get(/* withWhereKeyword= */ true));

        if (useGroupBy) {
            builder.append(" GROUP BY " + GROUP_BY_COLUMN_NAME);
        }

        OrderByClause orderByClause = new OrderByClause();
        orderByClause.addOrderByClause(mTimeColumnName, true);
        builder.append(orderByClause.getOrderBy());

        if (Constants.DEBUG) {
            Slog.d(TAG, "Aggregation origin query: " + builder);
        }

        return builder.toString();
    }

    /**
     * Reads app info ids from {@code metaDataCursor}, resolves them to package names and sets them
     * as data origins on every stored result.
     */
    @SuppressWarnings("NullAway") // TODO(b/317029272): fix this suppression
    private void updateResultWithDataOriginPackageNames(Cursor metaDataCursor) {
        List<Long> packageIds = new ArrayList<>();
        while (metaDataCursor.moveToNext()) {
            packageIds.add(StorageUtils.getCursorLong(metaDataCursor, APP_INFO_ID_COLUMN_NAME));
        }
        List<String> packageNames = AppInfoHelper.getInstance().getPackageNames(packageIds);

        // The value handed to the function is the current mapping, no extra lookup is needed.
        mAggregateResults.replaceAll((groupIndex, result) -> result.setDataOrigins(packageNames));
    }

    /**
     * Returns the (start, end) pair of every group, in group order.
     *
     * <p>Requires the time splits to have been set via {@link #setTimeFilter} or {@link
     * #setGroupBy} beforehand.
     */
    public List<Pair<Long, Long>> getGroupSplitIntervals() {
        List<Pair<Long, Long>> groupIntervals = new ArrayList<>();
        long previous = mTimeSplits.get(0);
        for (int i = 1; i < mTimeSplits.size(); i++) {
            Pair<Long, Long> pair = new Pair<>(previous, mTimeSplits.get(i));
            groupIntervals.add(pair);
            previous = mTimeSplits.get(i);
        }

        return groupIntervals;
    }

    /**
     * Splits the window of a local time filter into consecutive groups of length {@code period}.
     *
     * @param timeFilter must be a {@link LocalTimeRangeFilter}
     * @param period length of each group
     * @return group boundaries in millis, starting with the filter start and ending with the
     *     filter end
     * @throws IllegalArgumentException if the number of groups exceeds the allowed maximum
     */
    private List<Long> getGroupSplitsForPeriod(TimeRangeFilter timeFilter, Period period) {
        LocalTimeRangeFilter localTimeFilter = (LocalTimeRangeFilter) timeFilter;
        LocalDateTime filterStart = localTimeFilter.getStartTime();
        LocalDateTime filterEnd = localTimeFilter.getEndTime();

        List<Long> splits = new ArrayList<>();
        splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(filterStart));

        LocalDateTime currentEnd = filterStart.plus(period);
        while (!currentEnd.isAfter(filterEnd)) {
            splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(currentEnd));
            currentEnd = currentEnd.plus(period);

            // Also terminates the loop for periods that never move past the filter end.
            if (splits.size() > MAX_NUMBER_OF_GROUPS) {
                throw new IllegalArgumentException(
                        "Number of groups must not exceed " + MAX_NUMBER_OF_GROUPS);
            }
        }

        // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
        long filterEndMillis = TimeRangeFilterHelper.getMillisOfLocalTime(filterEnd);
        if (splits.get(splits.size() - 1) < filterEndMillis) {
            splits.add(filterEndMillis);
        }
        return splits;
    }

    /**
     * Splits the window of a time filter into consecutive groups of length {@code duration}.
     *
     * @param timeRangeFilter window to split
     * @param duration length of each group; must be positive
     * @return group boundaries in millis, starting with the filter start
     * @throws IllegalArgumentException if {@code duration} is not positive or the number of groups
     *     exceeds the allowed maximum
     */
    private List<Long> getGroupSplitsForDuration(
            TimeRangeFilter timeRangeFilter, Duration duration) {
        long groupByStart = TimeRangeFilterHelper.getFilterStartTimeMillis(timeRangeFilter);
        long groupByEnd = TimeRangeFilterHelper.getFilterEndTimeMillis(timeRangeFilter);
        long groupDurationMillis = duration.toMillis();

        // A zero duration would divide by zero below and a negative one would never terminate.
        if (groupDurationMillis <= 0) {
            throw new IllegalArgumentException("Aggregation duration must be positive");
        }

        if ((groupByEnd - groupByStart) / groupDurationMillis > MAX_NUMBER_OF_GROUPS) {
            throw new IllegalArgumentException(
                    "Number of buckets must not exceed " + MAX_NUMBER_OF_GROUPS);
        }

        List<Long> splits = new ArrayList<>();
        splits.add(groupByStart);
        long lastSplit = groupByStart;
        // Compare against the remaining window instead of computing lastSplit + duration first,
        // so a very large duration cannot overflow and wrap around.
        while (groupByEnd - lastSplit >= groupDurationMillis) {
            lastSplit += groupDurationMillis;
            splits.add(lastSplit);
        }

        // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
        if (lastSplit < groupByEnd) {
            splits.add(groupByEnd);
        }
        return splits;
    }

    /**
     * Stores one result per value returned by {@link RecordHelper#deriveAggregate}, keyed by the
     * value's position in the returned array.
     */
    private void deriveAggregate(Cursor cursor) {
        double[] derivedAggregateArray = mRecordHelper.deriveAggregate(cursor, this);
        int index = 0;
        cursor.moveToFirst();
        for (double aggregate : derivedAggregateArray) {
            mAggregateResults.put(
                    index, mRecordHelper.getAggregateResult(cursor, mAggregationType, aggregate));
            index++;
        }
    }
}