1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.data; 18 19 import static org.junit.Assert.assertArrayEquals; 20 import static org.junit.Assert.assertEquals; 21 import static org.junit.Assert.assertFalse; 22 23 import android.content.Context; 24 import android.database.Cursor; 25 import android.database.sqlite.SQLiteDatabase; 26 import android.util.Log; 27 28 import androidx.test.core.app.ApplicationProvider; 29 30 import com.android.adservices.data.measurement.MeasurementDbHelper; 31 import com.android.adservices.data.shared.SharedDbHelper; 32 33 import com.google.common.collect.ImmutableSet; 34 35 import java.util.ArrayList; 36 import java.util.Collections; 37 import java.util.List; 38 import java.util.Set; 39 40 public final class DbTestUtil { 41 private static final Context sContext = ApplicationProvider.getApplicationContext(); 42 private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db"; 43 private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db"; 44 private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db"; 45 46 private static DbHelper sSingleton; 47 private static MeasurementDbHelper sMsmtSingleton; 48 private static SharedDbHelper sSharedSingleton; 49 50 /** Erases all data from the table rows */ deleteTable(String tableName)51 public static void deleteTable(String tableName) { 52 SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase(); 53 if (db == null) { 54 return; 55 } 56 57 db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null); 58 } 59 60 /** 61 * Create an instance of database instance for testing. 62 * 63 * @return a test database 64 */ getDbHelperForTest()65 public static DbHelper getDbHelperForTest() { 66 synchronized (DbHelper.class) { 67 if (sSingleton == null) { 68 sSingleton = 69 new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION_7); 70 } 71 return sSingleton; 72 } 73 } 74 getMeasurementDbHelperForTest()75 public static MeasurementDbHelper getMeasurementDbHelperForTest() { 76 synchronized (MeasurementDbHelper.class) { 77 if (sMsmtSingleton == null) { 78 sMsmtSingleton = 79 new MeasurementDbHelper( 80 sContext, 81 MSMT_DATABASE_NAME_FOR_TEST, 82 MeasurementDbHelper.CURRENT_DATABASE_VERSION, 83 getDbHelperForTest()); 84 } 85 return sMsmtSingleton; 86 } 87 } 88 getSharedDbHelperForTest()89 public static SharedDbHelper getSharedDbHelperForTest() { 90 synchronized (SharedDbHelper.class) { 91 if (sSharedSingleton == null) { 92 sSharedSingleton = 93 new SharedDbHelper( 94 sContext, 95 SHARED_DATABASE_NAME_FOR_TEST, 96 SharedDbHelper.DATABASE_VERSION_V4, 97 getDbHelperForTest()); 98 } 99 return sSharedSingleton; 100 } 101 } 102 103 /** Return true if table exists in the DB and column count matches. */ doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)104 public static boolean doesTableExistAndColumnCountMatch( 105 SQLiteDatabase db, String tableName, int columnCount) { 106 final Set<String> tableColumns = getTableColumns(db, tableName); 107 int actualCol = tableColumns.size(); 108 Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol); 109 return tableColumns.size() == columnCount; 110 } 111 112 /** Returns column names of the table. */ getTableColumns(SQLiteDatabase db, String tableName)113 public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) { 114 String query = 115 "select p.name from sqlite_master s " 116 + "join pragma_table_info(s.name) p " 117 + "where s.tbl_name = '" 118 + tableName 119 + "'"; 120 Cursor cursor = db.rawQuery(query, null); 121 if (cursor == null) { 122 throw new IllegalArgumentException("Cursor is null."); 123 } 124 125 ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder(); 126 while (cursor.moveToNext()) { 127 tableColumns.add(cursor.getString(0)); 128 } 129 130 return tableColumns.build(); 131 } 132 133 /** Return true if the given index exists in the DB. */ doesIndexExist(SQLiteDatabase db, String index)134 public static boolean doesIndexExist(SQLiteDatabase db, String index) { 135 String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'"; 136 Cursor cursor = db.rawQuery(query, null); 137 return cursor != null && cursor.getCount() > 0; 138 } 139 doesTableExist(SQLiteDatabase db, String table)140 public static boolean doesTableExist(SQLiteDatabase db, String table) { 141 String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'"; 142 Cursor cursor = db.rawQuery(query, null); 143 return cursor != null && cursor.getCount() > 0; 144 } 145 146 /** Return test database name */ getDatabaseNameForTest()147 public static String getDatabaseNameForTest() { 148 return DATABASE_NAME_FOR_TEST; 149 } 150 assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)151 public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) { 152 List<String> expectedTables = getTables(expectedDb); 153 List<String> actualTables = getTables(actualDb); 154 assertArrayEquals(expectedTables.toArray(), actualTables.toArray()); 155 assertTableSchemaEqual(expectedDb, actualDb, expectedTables); 156 assertIndexesEqual(expectedDb, actualDb, expectedTables); 157 } 158 assertMeasurementTablesDoNotExist(SQLiteDatabase db)159 public static void assertMeasurementTablesDoNotExist(SQLiteDatabase db) { 160 assertFalse(doesTableExist(db, "msmt_source")); 161 assertFalse(doesTableExist(db, "msmt_trigger")); 162 assertFalse(doesTableExist(db, "msmt_async_registration_contract")); 163 assertFalse(doesTableExist(db, "msmt_event_report")); 164 assertFalse(doesTableExist(db, "msmt_attribution")); 165 assertFalse(doesTableExist(db, "msmt_aggregate_report")); 166 assertFalse(doesTableExist(db, "msmt_aggregate_encryption_key")); 167 assertFalse(doesTableExist(db, "msmt_debug_report")); 168 assertFalse(doesTableExist(db, "msmt_xna_ignored_sources")); 169 } 170 assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)171 private static void assertTableSchemaEqual( 172 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) { 173 for (String tableName : tableNames) { 174 Cursor columnsCursorExpected = 175 expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 176 Cursor columnsCursorActual = 177 actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 178 assertEquals( 179 "Table columns mismatch for " + tableName, 180 columnsCursorExpected.getCount(), 181 columnsCursorActual.getCount()); 182 183 // Checks the columns in order. Newly created columns should be inserted as the end. 184 while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) { 185 assertEquals( 186 "Column mismatch for " + tableName, 187 columnsCursorExpected.getString( 188 columnsCursorExpected.getColumnIndex("name")), 189 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name"))); 190 assertEquals( 191 "Column mismatch for " + tableName, 192 columnsCursorExpected.getString( 193 columnsCursorExpected.getColumnIndex("type")), 194 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type"))); 195 assertEquals( 196 "Column mismatch for " + tableName, 197 columnsCursorExpected.getInt( 198 columnsCursorExpected.getColumnIndex("notnull")), 199 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull"))); 200 assertEquals( 201 "Column mismatch for " + tableName, 202 columnsCursorExpected.getString( 203 columnsCursorExpected.getColumnIndex("dflt_value")), 204 columnsCursorActual.getString( 205 columnsCursorActual.getColumnIndex("dflt_value"))); 206 assertEquals( 207 "Column mismatch for " + tableName, 208 columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")), 209 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk"))); 210 } 211 212 columnsCursorExpected.close(); 213 columnsCursorActual.close(); 214 } 215 } 216 getTables(SQLiteDatabase db)217 private static List<String> getTables(SQLiteDatabase db) { 218 String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'"; 219 List<String> tables = new ArrayList<>(); 220 try (Cursor cursor = db.rawQuery(listTableQuery, null)) { 221 while (cursor.moveToNext()) { 222 tables.add(cursor.getString(cursor.getColumnIndex("name"))); 223 } 224 } 225 Collections.sort(tables); 226 return tables; 227 } 228 assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)229 private static void assertIndexesEqual( 230 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) { 231 for (String tableName : tables) { 232 String indexListQuery = 233 "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '" 234 + tableName 235 + "' ORDER BY name ASC"; 236 Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null); 237 Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null); 238 assertEquals( 239 "Table indexes mismatch for " + tableName, 240 indexListCursorExpected.getCount(), 241 indexListCursorActual.getCount()); 242 243 while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) { 244 String expectedIndexName = 245 indexListCursorExpected.getString( 246 indexListCursorExpected.getColumnIndex("name")); 247 assertEquals( 248 "Index mismatch for " + tableName, 249 expectedIndexName, 250 indexListCursorActual.getString( 251 indexListCursorActual.getColumnIndex("name"))); 252 253 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName); 254 } 255 256 indexListCursorExpected.close(); 257 indexListCursorActual.close(); 258 } 259 } 260 assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)261 private static void assertIndexInfoEqual( 262 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) { 263 Cursor indexInfoCursorExpected = 264 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 265 Cursor indexInfoCursorActual = 266 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 267 assertEquals( 268 "Index columns count mismatch for " + indexName, 269 indexInfoCursorExpected.getCount(), 270 indexInfoCursorActual.getCount()); 271 272 while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) { 273 assertEquals( 274 "Index info mismatch for " + indexName, 275 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")), 276 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno"))); 277 assertEquals( 278 "Index info mismatch for " + indexName, 279 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")), 280 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid"))); 281 assertEquals( 282 "Index info mismatch for " + indexName, 283 indexInfoCursorExpected.getString( 284 indexInfoCursorExpected.getColumnIndex("name")), 285 indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name"))); 286 } 287 288 indexInfoCursorExpected.close(); 289 indexInfoCursorActual.close(); 290 } 291 } 292