1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.healthconnect.storage.request;
18 
19 import static android.health.connect.datatypes.AggregationType.AVG;
20 import static android.health.connect.datatypes.AggregationType.COUNT;
21 import static android.health.connect.datatypes.AggregationType.MAX;
22 import static android.health.connect.datatypes.AggregationType.MIN;
23 import static android.health.connect.datatypes.AggregationType.SUM;
24 
25 import static com.android.server.healthconnect.storage.datatypehelpers.RecordHelper.APP_INFO_ID_COLUMN_NAME;
26 
27 import android.annotation.NonNull;
28 import android.database.Cursor;
29 import android.health.connect.AggregateResult;
30 import android.health.connect.Constants;
31 import android.health.connect.LocalTimeRangeFilter;
32 import android.health.connect.TimeRangeFilter;
33 import android.health.connect.TimeRangeFilterHelper;
34 import android.health.connect.datatypes.AggregationType;
35 import android.util.ArrayMap;
36 import android.util.Pair;
37 import android.util.Slog;
38 
39 import com.android.server.healthconnect.storage.TransactionManager;
40 import com.android.server.healthconnect.storage.datatypehelpers.AppInfoHelper;
41 import com.android.server.healthconnect.storage.datatypehelpers.RecordHelper;
42 import com.android.server.healthconnect.storage.datatypehelpers.aggregation.PriorityRecordsAggregator;
43 import com.android.server.healthconnect.storage.utils.OrderByClause;
44 import com.android.server.healthconnect.storage.utils.SqlJoin;
45 import com.android.server.healthconnect.storage.utils.StorageUtils;
46 import com.android.server.healthconnect.storage.utils.WhereClauses;
47 
48 import java.time.Duration;
49 import java.time.LocalDateTime;
50 import java.time.Period;
51 import java.util.ArrayList;
52 import java.util.List;
53 import java.util.Map;
54 
55 /**
56  * A request for {@link TransactionManager} to query the DB for aggregation results
57  *
58  * @hide
59  */
60 public class AggregateTableRequest {
61     private static final String TAG = "HealthConnectAggregate";
62     private static final String GROUP_BY_COLUMN_NAME = "category";
63 
64     private static final int MAX_NUMBER_OF_GROUPS = Constants.MAXIMUM_PAGE_SIZE;
65 
66     private final String mTableName;
67     private final List<String> mColumnNamesToAggregate;
68     private final AggregationType<?> mAggregationType;
69     private final RecordHelper<?> mRecordHelper;
70     private final Map<Integer, AggregateResult<?>> mAggregateResults = new ArrayMap<>();
71 
72     /**
73      * Represents "start time" for interval record, and "time" for instant record.
74      *
75      * <p>{@link #mUseLocalTime} is already taken into account when this field is set, meaning if
76      * {@link #mUseLocalTime} is {@code true}, then this field represent local time, otherwise
77      * physical time.
78      */
79     private final String mTimeColumnName;
80 
81     private final WhereClauses mWhereClauses;
82     private final SqlJoin mSqlJoin;
83     private String mGroupByColumnName;
84     private int mGroupBySize = 1;
85     private final List<String> mAdditionalColumnsToFetch;
86     private final AggregateParams.PriorityAggregationExtraParams mPriorityParams;
87     private final boolean mUseLocalTime;
88     private List<Long> mTimeSplits;
89 
90     @SuppressWarnings("NullAway.Init") // TODO(b/317029272): fix this suppression
AggregateTableRequest( AggregateParams params, AggregationType<?> aggregationType, RecordHelper<?> recordHelper, WhereClauses whereClauses, boolean useLocalTime)91     public AggregateTableRequest(
92             AggregateParams params,
93             AggregationType<?> aggregationType,
94             RecordHelper<?> recordHelper,
95             WhereClauses whereClauses,
96             boolean useLocalTime) {
97         mTableName = params.getTableName();
98         mColumnNamesToAggregate = params.getColumnsToFetch();
99         mTimeColumnName = params.getTimeColumnName();
100         mAggregationType = aggregationType;
101         mRecordHelper = recordHelper;
102         mSqlJoin = params.getJoin();
103         mPriorityParams = params.getPriorityAggregationExtraParams();
104         mWhereClauses = whereClauses;
105         mAdditionalColumnsToFetch = new ArrayList<>();
106         mAdditionalColumnsToFetch.add(params.getTimeOffsetColumnName());
107         mAdditionalColumnsToFetch.add(mTimeColumnName);
108         String endTimeColumnName = params.getExtraTimeColumnName();
109         if (endTimeColumnName != null) {
110             mAdditionalColumnsToFetch.add(endTimeColumnName);
111         }
112         mUseLocalTime = useLocalTime;
113     }
114 
115     /**
116      * @return {@link AggregationType} for this request
117      */
getAggregationType()118     public AggregationType<?> getAggregationType() {
119         return mAggregationType;
120     }
121 
122     /**
123      * @return {@link RecordHelper} for this request
124      */
getRecordHelper()125     public RecordHelper<?> getRecordHelper() {
126         return mRecordHelper;
127     }
128 
129     /**
130      * @return results fetched after performing aggregate operation for this class.
131      *     <p>Note: Only available after the call to {@link
132      *     TransactionManager#populateWithAggregation} has been made
133      */
getAggregateResults()134     public List<AggregateResult<?>> getAggregateResults() {
135         List<AggregateResult<?>> aggregateResults = new ArrayList<>(mGroupBySize);
136         for (int i = 0; i < mGroupBySize; i++) {
137             aggregateResults.add(mAggregateResults.get(i));
138         }
139 
140         return aggregateResults;
141     }
142 
143     /** Returns SQL statement to get data origins for the aggregation operation */
getCommandToFetchAggregateMetadata()144     public String getCommandToFetchAggregateMetadata() {
145         final StringBuilder builder = new StringBuilder("SELECT DISTINCT ");
146         builder.append(APP_INFO_ID_COLUMN_NAME).append(", ");
147         return appendAggregateCommand(builder, /* isMetadata= */ true);
148     }
149 
150     /** Returns name of the main time column (start time for Interval, time for Instant records) */
getTimeColumnName()151     public String getTimeColumnName() {
152         return mTimeColumnName;
153     }
154 
155     /** Returns whether request is using local time instead of physical one. */
getUseLocalTime()156     public boolean getUseLocalTime() {
157         return mUseLocalTime;
158     }
159 
160     /** Returns SQL statement to perform aggregation operation */
161     @NonNull
getAggregationCommand()162     public String getAggregationCommand() {
163         final StringBuilder builder = new StringBuilder("SELECT ");
164         String aggCommand;
165         boolean usingPriority =
166                 StorageUtils.supportsPriority(
167                                 mRecordHelper.getRecordIdentifier(),
168                                 mAggregationType.getAggregateOperationType())
169                         || StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier());
170         if (usingPriority) {
171             for (String columnName : mColumnNamesToAggregate) {
172                 builder.append(columnName).append(", ");
173             }
174         } else {
175             aggCommand = getSqlCommandFor(mAggregationType.getAggregateOperationType());
176 
177             for (String columnName : mColumnNamesToAggregate) {
178                 builder.append(aggCommand)
179                         .append("(")
180                         .append(columnName)
181                         .append(")")
182                         .append(" as ")
183                         .append(columnName)
184                         .append(", ");
185             }
186         }
187 
188         if (mAdditionalColumnsToFetch != null) {
189             for (String additionalColumnToFetch : mAdditionalColumnsToFetch) {
190                 builder.append(additionalColumnToFetch).append(", ");
191             }
192         }
193 
194         return appendAggregateCommand(builder, usingPriority);
195     }
196 
197     /** Sets time filter for table request. */
setTimeFilter(long startTime, long endTime)198     public AggregateTableRequest setTimeFilter(long startTime, long endTime) {
199         // Return if the params will result in no impact on the query
200         if (startTime < 0 || endTime < startTime) {
201             return this;
202         }
203 
204         mTimeSplits = List.of(startTime, endTime);
205         return this;
206     }
207 
208     /** Sets group by fields. */
setGroupBy( String columnName, Period period, Duration duration, TimeRangeFilter timeRangeFilter)209     public void setGroupBy(
210             String columnName, Period period, Duration duration, TimeRangeFilter timeRangeFilter) {
211         mGroupByColumnName = columnName;
212         if (period != null) {
213             mTimeSplits = getGroupSplitsForPeriod(timeRangeFilter, period);
214         } else if (duration != null) {
215             mTimeSplits = getGroupSplitsForDuration(timeRangeFilter, duration);
216         } else {
217             throw new IllegalArgumentException(
218                     "Either aggregation period or duration should be not null");
219         }
220         mGroupBySize = mTimeSplits.size() - 1;
221 
222         if (Constants.DEBUG) {
223             Slog.d(
224                     TAG,
225                     "Group aggregation splits: "
226                             + mTimeSplits
227                             + " number of groups: "
228                             + mGroupBySize);
229         }
230     }
231 
onResultsFetched(Cursor cursor, Cursor metaDataCursor)232     public void onResultsFetched(Cursor cursor, Cursor metaDataCursor) {
233         if (StorageUtils.isDerivedType(mRecordHelper.getRecordIdentifier())) {
234             deriveAggregate(cursor);
235         } else if (StorageUtils.supportsPriority(
236                 mRecordHelper.getRecordIdentifier(),
237                 mAggregationType.getAggregateOperationType())) {
238             processPriorityRequest(cursor);
239         } else {
240             processNoPrioritiesRequest(cursor);
241         }
242 
243         updateResultWithDataOriginPackageNames(metaDataCursor);
244     }
245 
processPriorityRequest(Cursor cursor)246     private void processPriorityRequest(Cursor cursor) {
247         List<Long> priorityList =
248                 StorageUtils.getAppIdPriorityList(mRecordHelper.getRecordIdentifier());
249         PriorityRecordsAggregator aggregator =
250                 new PriorityRecordsAggregator(
251                         mTimeSplits,
252                         priorityList,
253                         mAggregationType.getAggregationTypeIdentifier(),
254                         mPriorityParams,
255                         mUseLocalTime);
256         aggregator.calculateAggregation(cursor);
257         AggregateResult<?> result;
258         for (int groupNumber = 0; groupNumber < mGroupBySize; groupNumber++) {
259             if (aggregator.getResultForGroup(groupNumber) == null) {
260                 continue;
261             }
262 
263             if (mAggregationType.getAggregateResultClass() == Long.class) {
264                 result =
265                         new AggregateResult<>(
266                                 aggregator.getResultForGroup(groupNumber).longValue());
267             } else {
268                 result = new AggregateResult<>(aggregator.getResultForGroup(groupNumber));
269             }
270             mAggregateResults.put(
271                     groupNumber,
272                     result.setZoneOffset(aggregator.getZoneOffsetForGroup(groupNumber)));
273         }
274 
275         if (Constants.DEBUG) {
276             Slog.d(TAG, "Priority aggregation result: " + mAggregateResults);
277         }
278     }
279 
processNoPrioritiesRequest(Cursor cursor)280     private void processNoPrioritiesRequest(Cursor cursor) {
281         while (cursor.moveToNext()) {
282             mAggregateResults.put(
283                     StorageUtils.getCursorInt(cursor, GROUP_BY_COLUMN_NAME),
284                     mRecordHelper.getAggregateResult(cursor, mAggregationType));
285         }
286     }
287 
288     @SuppressWarnings("NullAway") // TODO(b/317029272): fix this suppression
getSqlCommandFor(@ggregationType.AggregateOperationType int type)289     private static String getSqlCommandFor(@AggregationType.AggregateOperationType int type) {
290         return switch (type) {
291             case MAX -> "MAX";
292             case MIN -> "MIN";
293             case AVG -> "AVG";
294             case SUM -> "SUM";
295             case COUNT -> "COUNT";
296             default -> null;
297         };
298     }
299 
appendAggregateCommand(StringBuilder builder, boolean isMetadata)300     private String appendAggregateCommand(StringBuilder builder, boolean isMetadata) {
301         boolean useGroupBy = mGroupByColumnName != null && !isMetadata;
302         if (useGroupBy) {
303             builder.append(" CASE ");
304             int groupByIndex = 0;
305             for (int i = 0; i < mTimeSplits.size() - 1; i++) {
306                 builder.append(" WHEN ")
307                         .append(mTimeColumnName)
308                         .append(" >= ")
309                         .append(mTimeSplits.get(i))
310                         .append(" AND ")
311                         .append(mTimeColumnName)
312                         .append(" < ")
313                         .append(mTimeSplits.get(i + 1))
314                         .append(" THEN ")
315                         .append(groupByIndex++);
316             }
317             builder.append(" END " + GROUP_BY_COLUMN_NAME + " ");
318         } else {
319             builder.setLength(builder.length() - 2); // Remove the last 2 char i.e. ", "
320         }
321 
322         builder.append(" FROM ").append(mTableName);
323         if (mSqlJoin != null) {
324             builder.append(mSqlJoin.getJoinCommand());
325         }
326 
327         builder.append(mWhereClauses.get(/* withWhereKeyword= */ true));
328 
329         if (useGroupBy) {
330             builder.append(" GROUP BY " + GROUP_BY_COLUMN_NAME);
331         }
332 
333         OrderByClause orderByClause = new OrderByClause();
334         orderByClause.addOrderByClause(mTimeColumnName, true);
335         builder.append(orderByClause.getOrderBy());
336 
337         if (Constants.DEBUG) {
338             Slog.d(TAG, "Aggregation origin query: " + builder);
339         }
340 
341         return builder.toString();
342     }
343 
344     @SuppressWarnings("NullAway") // TODO(b/317029272): fix this suppression
updateResultWithDataOriginPackageNames(Cursor metaDataCursor)345     private void updateResultWithDataOriginPackageNames(Cursor metaDataCursor) {
346         List<Long> packageIds = new ArrayList<>();
347         while (metaDataCursor.moveToNext()) {
348             packageIds.add(StorageUtils.getCursorLong(metaDataCursor, APP_INFO_ID_COLUMN_NAME));
349         }
350         List<String> packageNames = AppInfoHelper.getInstance().getPackageNames(packageIds);
351 
352         mAggregateResults.replaceAll(
353                 (n, v) -> mAggregateResults.get(n).setDataOrigins(packageNames));
354     }
355 
getGroupSplitIntervals()356     public List<Pair<Long, Long>> getGroupSplitIntervals() {
357         List<Pair<Long, Long>> groupIntervals = new ArrayList<>();
358         long previous = mTimeSplits.get(0);
359         for (int i = 1; i < mTimeSplits.size(); i++) {
360             Pair<Long, Long> pair = new Pair<>(previous, mTimeSplits.get(i));
361             groupIntervals.add(pair);
362             previous = mTimeSplits.get(i);
363         }
364 
365         return groupIntervals;
366     }
367 
getGroupSplitsForPeriod(TimeRangeFilter timeFilter, Period period)368     private List<Long> getGroupSplitsForPeriod(TimeRangeFilter timeFilter, Period period) {
369         LocalDateTime filterStart = ((LocalTimeRangeFilter) timeFilter).getStartTime();
370         LocalDateTime filterEnd = ((LocalTimeRangeFilter) timeFilter).getEndTime();
371 
372         List<Long> splits = new ArrayList<>();
373         splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(filterStart));
374 
375         LocalDateTime currentEnd = filterStart.plus(period);
376         while (!currentEnd.isAfter(filterEnd)) {
377             splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(currentEnd));
378             currentEnd = currentEnd.plus(period);
379 
380             if (splits.size() > MAX_NUMBER_OF_GROUPS) {
381                 throw new IllegalArgumentException(
382                         "Number of groups must not exceed " + MAX_NUMBER_OF_GROUPS);
383             }
384         }
385 
386         // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
387         if (splits.get(splits.size() - 1) < TimeRangeFilterHelper.getMillisOfLocalTime(filterEnd)) {
388             splits.add(TimeRangeFilterHelper.getMillisOfLocalTime(filterEnd));
389         }
390         return splits;
391     }
392 
getGroupSplitsForDuration( TimeRangeFilter timeRangeFilter, Duration duration)393     private List<Long> getGroupSplitsForDuration(
394             TimeRangeFilter timeRangeFilter, Duration duration) {
395         long groupByStart = TimeRangeFilterHelper.getFilterStartTimeMillis(timeRangeFilter);
396         long groupByEnd = TimeRangeFilterHelper.getFilterEndTimeMillis(timeRangeFilter);
397         long groupDurationMillis = duration.toMillis();
398 
399         if ((groupByEnd - groupByStart) / groupDurationMillis > MAX_NUMBER_OF_GROUPS) {
400             throw new IllegalArgumentException(
401                     "Number of buckets must not exceed " + MAX_NUMBER_OF_GROUPS);
402         }
403 
404         List<Long> splits = new ArrayList<>();
405         splits.add(groupByStart);
406         long currentEnd = groupByStart + groupDurationMillis;
407         while (currentEnd <= groupByEnd) {
408             splits.add(currentEnd);
409             currentEnd += groupDurationMillis;
410         }
411 
412         // If the last group doesn't fit the rest of the window, we cut it up to filterEnd
413         if (splits.get(splits.size() - 1) < groupByEnd) {
414             splits.add(groupByEnd);
415         }
416         return splits;
417     }
418 
deriveAggregate(Cursor cursor)419     private void deriveAggregate(Cursor cursor) {
420         double[] derivedAggregateArray = mRecordHelper.deriveAggregate(cursor, this);
421         int index = 0;
422         cursor.moveToFirst();
423         for (double aggregate : derivedAggregateArray) {
424             mAggregateResults.put(
425                     index, mRecordHelper.getAggregateResult(cursor, mAggregationType, aggregate));
426             index++;
427         }
428     }
429 }
430