1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.ondevicepersonalization.services.data.events;
18 
import android.content.ContentValues;
import android.content.Context;
import android.database.Cursor;
import android.database.SQLException;
import android.database.sqlite.SQLiteDatabase;
import android.database.sqlite.SQLiteOpenHelper;

import com.android.ondevicepersonalization.internal.util.LoggerFactory;
import com.android.ondevicepersonalization.services.util.OnDevicePersonalizationFlatbufferUtils;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
34 
35 /**
36  * Dao used to manage and create in-memory table for joined Events and Queries tables
37  */
38 public class JoinedTableDao {
39     /** Map of column name to {@link ColumnSchema} of columns provided by OnDevicePersonalization */
40     public static final Map<String, ColumnSchema> ODP_PROVIDED_COLUMNS;
41     // TODO(298682670): Finalize provided column and table names.
42     public static final String SERVICE_NAME_COL = "serviceName";
43     public static final String TYPE_COL = "type";
44     public static final String EVENT_TIME_MILLIS_COL = "eventTimeMillis";
45     public static final String QUERY_TIME_MILLIS_COL = "queryTimeMillis";
46     public static final String TABLE_NAME = "odp_joined_table";
47     private static final String TAG = "JoinedTableDao";
48     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
49 
50     static {
51         ODP_PROVIDED_COLUMNS = new HashMap<>();
ODP_PROVIDED_COLUMNS.put(SERVICE_NAME_COL, new ColumnSchema.Builder().setName( SERVICE_NAME_COL).setType(ColumnSchema.SQL_DATA_TYPE_TEXT).build())52         ODP_PROVIDED_COLUMNS.put(SERVICE_NAME_COL, new ColumnSchema.Builder().setName(
53                 SERVICE_NAME_COL).setType(ColumnSchema.SQL_DATA_TYPE_TEXT).build());
ODP_PROVIDED_COLUMNS.put(TYPE_COL, new ColumnSchema.Builder().setName(TYPE_COL).setType( ColumnSchema.SQL_DATA_TYPE_INTEGER).build())54         ODP_PROVIDED_COLUMNS.put(TYPE_COL, new ColumnSchema.Builder().setName(TYPE_COL).setType(
55                 ColumnSchema.SQL_DATA_TYPE_INTEGER).build());
ODP_PROVIDED_COLUMNS.put(EVENT_TIME_MILLIS_COL, new ColumnSchema.Builder().setName( EVENT_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build())56         ODP_PROVIDED_COLUMNS.put(EVENT_TIME_MILLIS_COL, new ColumnSchema.Builder().setName(
57                 EVENT_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build());
ODP_PROVIDED_COLUMNS.put(QUERY_TIME_MILLIS_COL, new ColumnSchema.Builder().setName( QUERY_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build())58         ODP_PROVIDED_COLUMNS.put(QUERY_TIME_MILLIS_COL, new ColumnSchema.Builder().setName(
59                 QUERY_TIME_MILLIS_COL).setType(ColumnSchema.SQL_DATA_TYPE_INTEGER).build());
60     }
61 
62     private final SQLiteOpenHelper mDbHelper;
63     private final Map<String, ColumnSchema> mColumns;
64 
JoinedTableDao(List<ColumnSchema> columnSchemaList, long fromEventId, long fromQueryId, Context context)65     public JoinedTableDao(List<ColumnSchema> columnSchemaList, long fromEventId, long fromQueryId,
66             Context context) {
67         if (!validateColumns(columnSchemaList)) {
68             throw new IllegalArgumentException("Provided columns are invalid.");
69         }
70         // Move the List to a HashMap <ColumnName, ColumnSchema> for easier access.
71         mColumns = columnSchemaList.stream().collect(Collectors.toMap(
72                 ColumnSchema::getName,
73                 Function.identity(),
74                 (v1, v2) -> {
75                     // Throw on duplicate keys.
76                     throw new IllegalArgumentException("Duplicate key found in columnSchemaList");
77                 },
78                 HashMap::new));
79         mDbHelper = createInMemoryTable(columnSchemaList, context);
80         populateTable(fromEventId, fromQueryId, context);
81     }
82 
createInMemoryTable(List<ColumnSchema> columnSchemaList, Context context)83     private static SQLiteOpenHelper createInMemoryTable(List<ColumnSchema> columnSchemaList,
84             Context context) {
85         List<String> columns = columnSchemaList.stream().map(ColumnSchema::toString).collect(
86                 Collectors.toList());
87         String createTableStatement = "CREATE TABLE IF NOT EXISTS " + TABLE_NAME + " ("
88                 + String.join(",", columns) + ")";
89         SQLiteOpenHelper sqLiteOpenHelper = new SQLiteOpenHelper(context, null, null, 1) {
90             @Override
91             public void onCreate(SQLiteDatabase db) {
92                 // Do nothing.
93             }
94 
95             @Override
96             public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) {
97                 // Do nothing. Should never be called.
98             }
99         };
100 
101         try {
102             sqLiteOpenHelper.getReadableDatabase().execSQL(createTableStatement);
103         } catch (SQLException e) {
104             sLogger.e(e, TAG + " : Failed to create JoinedTable database in memory.");
105             throw new IllegalStateException(e);
106         }
107         return sqLiteOpenHelper;
108     }
109 
validateColumns(List<ColumnSchema> columnSchemaList)110     private static boolean validateColumns(List<ColumnSchema> columnSchemaList) {
111         if (columnSchemaList.size() == 0) {
112             sLogger.d(TAG + ": Empty columnSchemaList provided");
113             return false;
114         }
115         for (ColumnSchema columnSchema : columnSchemaList) {
116             // Validate any ODP_PROVIDED_COLUMNS are the correct type
117             if (ODP_PROVIDED_COLUMNS.containsKey(columnSchema.getName())) {
118                 ColumnSchema expected = ODP_PROVIDED_COLUMNS.get(columnSchema.getName());
119                 if (expected.getType() != columnSchema.getType()) {
120                     sLogger.d(TAG
121                                     + ": ODP column %s of type %s provided does not match "
122                                     + "expected type %s",
123                             columnSchema.getName(), columnSchema.getType(), expected.getType());
124                     return false;
125                 }
126             }
127         }
128         // TODO(298225729): Additional validation on column name formatting.
129         return true;
130     }
131 
132     /**
133      * Executes the given query on the in-memory db.
134      *
135      * @return Cursor holding result of the query.
136      */
rawQuery(String sql)137     public Cursor rawQuery(String sql) {
138         SQLiteDatabase db = mDbHelper.getReadableDatabase();
139         // TODO(298225729): Determine return format.
140         return db.rawQuery(sql, null);
141     }
142 
populateTable(long fromEventId, long fromQueryId, Context context)143     private void populateTable(long fromEventId, long fromQueryId, Context context) {
144         EventsDao eventsDao = EventsDao.getInstance(context);
145         List<JoinedEvent> joinedEventList = eventsDao.readAllNewRows(fromEventId,
146                 fromQueryId);
147 
148         SQLiteDatabase db = mDbHelper.getWritableDatabase();
149         try {
150             db.beginTransactionNonExclusive();
151             for (JoinedEvent joinedEvent : joinedEventList) {
152                 if (joinedEvent.getEventId() == 0) {
153                     // Process Query-only rows
154                     if (joinedEvent.getQueryData() != null) {
155                         List<ContentValues> queryFieldRows =
156                                 OnDevicePersonalizationFlatbufferUtils
157                                         .getContentValuesFromQueryData(
158                                                 joinedEvent.getQueryData());
159                         for (ContentValues queryRow : queryFieldRows) {
160                             ContentValues insertValues = new ContentValues();
161                             insertValues.putAll(extractValidColumns(queryRow));
162                             insertValues.putAll(addProvidedColumns(joinedEvent));
163                             long insertResult = db.insert(TABLE_NAME, null, insertValues);
164                             if (insertResult == -1) {
165                                 throw new IllegalStateException("Failed to insert row into SQL DB");
166                             }
167                         }
168                     }
169                 } else {
170                     ContentValues insertValues = new ContentValues();
171                     // Add eventData columns
172                     if (joinedEvent.getEventData() != null) {
173                         ContentValues eventData =
174                                 OnDevicePersonalizationFlatbufferUtils
175                                         .getContentValuesFromEventData(
176                                                 joinedEvent.getEventData());
177                         insertValues.putAll(extractValidColumns(eventData));
178                     }
179                     // Add queryData columns
180                     if (joinedEvent.getQueryData() != null) {
181                         ContentValues queryData =
182                                 OnDevicePersonalizationFlatbufferUtils
183                                         .getContentValuesRowFromQueryData(
184                                                 joinedEvent.getQueryData(),
185                                                 joinedEvent.getRowIndex());
186                         insertValues.putAll(extractValidColumns(queryData));
187                     }
188                     // Add ODP provided columns
189                     insertValues.putAll(addProvidedColumns(joinedEvent));
190                     long insertResult = db.insert(TABLE_NAME, null, insertValues);
191                     if (insertResult == -1) {
192                         throw new IllegalStateException("Failed to insert row into SQL DB");
193                     }
194                 }
195             }
196             db.setTransactionSuccessful();
197         } finally {
198             db.endTransaction();
199         }
200     }
201 
addProvidedColumns(JoinedEvent joinedEvent)202     private ContentValues addProvidedColumns(JoinedEvent joinedEvent) {
203         ContentValues result = new ContentValues();
204         if (mColumns.containsKey(SERVICE_NAME_COL)) {
205             result.put(SERVICE_NAME_COL,
206                     joinedEvent.getServiceName());
207         }
208         if (mColumns.containsKey(TYPE_COL)) {
209             result.put(TYPE_COL, joinedEvent.getType());
210         }
211         if (mColumns.containsKey(EVENT_TIME_MILLIS_COL)) {
212             result.put(EVENT_TIME_MILLIS_COL, joinedEvent.getEventTimeMillis());
213         }
214         if (mColumns.containsKey(QUERY_TIME_MILLIS_COL)) {
215             result.put(QUERY_TIME_MILLIS_COL, joinedEvent.getQueryTimeMillis());
216         }
217         return result;
218     }
219 
extractValidColumns(ContentValues data)220     private ContentValues extractValidColumns(ContentValues data) {
221         ContentValues result = new ContentValues();
222         for (String key : data.keySet()) {
223             if (mColumns.containsKey(key)) {
224                 Object value = data.get(key);
225                 int sqlType = mColumns.get(key).getType();
226                 if (value instanceof Byte) {
227                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) {
228                         result.put(key, (Byte) value);
229                     }
230                 } else if (value instanceof Short) {
231                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) {
232                         result.put(key, (Short) value);
233                     }
234                 } else if (value instanceof Integer) {
235                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) {
236                         result.put(key, (Integer) value);
237                     }
238                 } else if (value instanceof Long) {
239                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) {
240                         result.put(key, (Long) value);
241                     }
242                 } else if (value instanceof Float) {
243                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_REAL) {
244                         result.put(key, (Float) value);
245                     }
246                 } else if (value instanceof Double) {
247                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_REAL) {
248                         result.put(key, (Double) value);
249                     }
250                 } else if (value instanceof String) {
251                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_TEXT) {
252                         result.put(key, (String) value);
253                     }
254                 } else if (value instanceof byte[]) {
255                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_BLOB) {
256                         result.put(key, (byte[]) value);
257                     }
258                 } else if (value instanceof Boolean) {
259                     if (sqlType == ColumnSchema.SQL_DATA_TYPE_INTEGER) {
260                         result.put(key, (Boolean) value);
261                     }
262                 }
263             }
264         }
265         return result;
266     }
267 }
268