1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.data;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertFalse;
22 
23 import android.content.Context;
24 import android.database.Cursor;
25 import android.database.sqlite.SQLiteDatabase;
26 import android.util.Log;
27 
28 import androidx.test.core.app.ApplicationProvider;
29 
30 import com.android.adservices.data.measurement.MeasurementDbHelper;
31 import com.android.adservices.data.shared.SharedDbHelper;
32 
33 import com.google.common.collect.ImmutableSet;
34 
35 import java.util.ArrayList;
36 import java.util.Collections;
37 import java.util.List;
38 import java.util.Set;
39 
40 public final class DbTestUtil {
41     private static final Context sContext = ApplicationProvider.getApplicationContext();
42     private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db";
43     private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db";
44     private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db";
45 
46     private static DbHelper sSingleton;
47     private static MeasurementDbHelper sMsmtSingleton;
48     private static SharedDbHelper sSharedSingleton;
49 
50     /** Erases all data from the table rows */
deleteTable(String tableName)51     public static void deleteTable(String tableName) {
52         SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase();
53         if (db == null) {
54             return;
55         }
56 
57         db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null);
58     }
59 
60     /**
61      * Create an instance of database instance for testing.
62      *
63      * @return a test database
64      */
getDbHelperForTest()65     public static DbHelper getDbHelperForTest() {
66         synchronized (DbHelper.class) {
67             if (sSingleton == null) {
68                 sSingleton =
69                         new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION_7);
70             }
71             return sSingleton;
72         }
73     }
74 
getMeasurementDbHelperForTest()75     public static MeasurementDbHelper getMeasurementDbHelperForTest() {
76         synchronized (MeasurementDbHelper.class) {
77             if (sMsmtSingleton == null) {
78                 sMsmtSingleton =
79                         new MeasurementDbHelper(
80                                 sContext,
81                                 MSMT_DATABASE_NAME_FOR_TEST,
82                                 MeasurementDbHelper.CURRENT_DATABASE_VERSION,
83                                 getDbHelperForTest());
84             }
85             return sMsmtSingleton;
86         }
87     }
88 
getSharedDbHelperForTest()89     public static SharedDbHelper getSharedDbHelperForTest() {
90         synchronized (SharedDbHelper.class) {
91             if (sSharedSingleton == null) {
92                 sSharedSingleton =
93                         new SharedDbHelper(
94                                 sContext,
95                                 SHARED_DATABASE_NAME_FOR_TEST,
96                                 SharedDbHelper.DATABASE_VERSION_V4,
97                                 getDbHelperForTest());
98             }
99             return sSharedSingleton;
100         }
101     }
102 
103     /** Return true if table exists in the DB and column count matches. */
doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)104     public static boolean doesTableExistAndColumnCountMatch(
105             SQLiteDatabase db, String tableName, int columnCount) {
106         final Set<String> tableColumns = getTableColumns(db, tableName);
107         int actualCol = tableColumns.size();
108         Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol);
109         return tableColumns.size() == columnCount;
110     }
111 
112     /** Returns column names of the table. */
getTableColumns(SQLiteDatabase db, String tableName)113     public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) {
114         String query =
115                 "select p.name from sqlite_master s "
116                         + "join pragma_table_info(s.name) p "
117                         + "where s.tbl_name = '"
118                         + tableName
119                         + "'";
120         Cursor cursor = db.rawQuery(query, null);
121         if (cursor == null) {
122             throw new IllegalArgumentException("Cursor is null.");
123         }
124 
125         ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder();
126         while (cursor.moveToNext()) {
127             tableColumns.add(cursor.getString(0));
128         }
129 
130         return tableColumns.build();
131     }
132 
133     /** Return true if the given index exists in the DB. */
doesIndexExist(SQLiteDatabase db, String index)134     public static boolean doesIndexExist(SQLiteDatabase db, String index) {
135         String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'";
136         Cursor cursor = db.rawQuery(query, null);
137         return cursor != null && cursor.getCount() > 0;
138     }
139 
doesTableExist(SQLiteDatabase db, String table)140     public static boolean doesTableExist(SQLiteDatabase db, String table) {
141         String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'";
142         Cursor cursor = db.rawQuery(query, null);
143         return cursor != null && cursor.getCount() > 0;
144     }
145 
146     /** Return test database name */
getDatabaseNameForTest()147     public static String getDatabaseNameForTest() {
148         return DATABASE_NAME_FOR_TEST;
149     }
150 
assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)151     public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) {
152         List<String> expectedTables = getTables(expectedDb);
153         List<String> actualTables = getTables(actualDb);
154         assertArrayEquals(expectedTables.toArray(), actualTables.toArray());
155         assertTableSchemaEqual(expectedDb, actualDb, expectedTables);
156         assertIndexesEqual(expectedDb, actualDb, expectedTables);
157     }
158 
assertMeasurementTablesDoNotExist(SQLiteDatabase db)159     public static void assertMeasurementTablesDoNotExist(SQLiteDatabase db) {
160         assertFalse(doesTableExist(db, "msmt_source"));
161         assertFalse(doesTableExist(db, "msmt_trigger"));
162         assertFalse(doesTableExist(db, "msmt_async_registration_contract"));
163         assertFalse(doesTableExist(db, "msmt_event_report"));
164         assertFalse(doesTableExist(db, "msmt_attribution"));
165         assertFalse(doesTableExist(db, "msmt_aggregate_report"));
166         assertFalse(doesTableExist(db, "msmt_aggregate_encryption_key"));
167         assertFalse(doesTableExist(db, "msmt_debug_report"));
168         assertFalse(doesTableExist(db, "msmt_xna_ignored_sources"));
169     }
170 
assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)171     private static void assertTableSchemaEqual(
172             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) {
173         for (String tableName : tableNames) {
174             Cursor columnsCursorExpected =
175                     expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
176             Cursor columnsCursorActual =
177                     actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
178             assertEquals(
179                     "Table columns mismatch for " + tableName,
180                     columnsCursorExpected.getCount(),
181                     columnsCursorActual.getCount());
182 
183             // Checks the columns in order. Newly created columns should be inserted as the end.
184             while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) {
185                 assertEquals(
186                         "Column mismatch for " + tableName,
187                         columnsCursorExpected.getString(
188                                 columnsCursorExpected.getColumnIndex("name")),
189                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name")));
190                 assertEquals(
191                         "Column mismatch for " + tableName,
192                         columnsCursorExpected.getString(
193                                 columnsCursorExpected.getColumnIndex("type")),
194                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type")));
195                 assertEquals(
196                         "Column mismatch for " + tableName,
197                         columnsCursorExpected.getInt(
198                                 columnsCursorExpected.getColumnIndex("notnull")),
199                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull")));
200                 assertEquals(
201                         "Column mismatch for " + tableName,
202                         columnsCursorExpected.getString(
203                                 columnsCursorExpected.getColumnIndex("dflt_value")),
204                         columnsCursorActual.getString(
205                                 columnsCursorActual.getColumnIndex("dflt_value")));
206                 assertEquals(
207                         "Column mismatch for " + tableName,
208                         columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")),
209                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk")));
210             }
211 
212             columnsCursorExpected.close();
213             columnsCursorActual.close();
214         }
215     }
216 
getTables(SQLiteDatabase db)217     private static List<String> getTables(SQLiteDatabase db) {
218         String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'";
219         List<String> tables = new ArrayList<>();
220         try (Cursor cursor = db.rawQuery(listTableQuery, null)) {
221             while (cursor.moveToNext()) {
222                 tables.add(cursor.getString(cursor.getColumnIndex("name")));
223             }
224         }
225         Collections.sort(tables);
226         return tables;
227     }
228 
assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)229     private static void assertIndexesEqual(
230             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) {
231         for (String tableName : tables) {
232             String indexListQuery =
233                     "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '"
234                             + tableName
235                             + "' ORDER BY name ASC";
236             Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null);
237             Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null);
238             assertEquals(
239                     "Table indexes mismatch for " + tableName,
240                     indexListCursorExpected.getCount(),
241                     indexListCursorActual.getCount());
242 
243             while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) {
244                 String expectedIndexName =
245                         indexListCursorExpected.getString(
246                                 indexListCursorExpected.getColumnIndex("name"));
247                 assertEquals(
248                         "Index mismatch for " + tableName,
249                         expectedIndexName,
250                         indexListCursorActual.getString(
251                                 indexListCursorActual.getColumnIndex("name")));
252 
253                 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName);
254             }
255 
256             indexListCursorExpected.close();
257             indexListCursorActual.close();
258         }
259     }
260 
assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)261     private static void assertIndexInfoEqual(
262             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) {
263         Cursor indexInfoCursorExpected =
264                 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
265         Cursor indexInfoCursorActual =
266                 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
267         assertEquals(
268                 "Index columns count mismatch for " + indexName,
269                 indexInfoCursorExpected.getCount(),
270                 indexInfoCursorActual.getCount());
271 
272         while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) {
273             assertEquals(
274                     "Index info mismatch for " + indexName,
275                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")),
276                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno")));
277             assertEquals(
278                     "Index info mismatch for " + indexName,
279                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")),
280                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid")));
281             assertEquals(
282                     "Index info mismatch for " + indexName,
283                     indexInfoCursorExpected.getString(
284                             indexInfoCursorExpected.getColumnIndex("name")),
285                     indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name")));
286         }
287 
288         indexInfoCursorExpected.close();
289         indexInfoCursorActual.close();
290     }
291 }
292