1 // Copyright 2021, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use anyhow::{anyhow, Context, Result};
16 use rusqlite::{params, OptionalExtension, Transaction, NO_PARAMS};
17 
create_or_get_version(tx: &Transaction, current_version: u32) -> Result<u32>18 pub fn create_or_get_version(tx: &Transaction, current_version: u32) -> Result<u32> {
19     tx.execute(
20         "CREATE TABLE IF NOT EXISTS persistent.version (
21                 id INTEGER PRIMARY KEY,
22                 version INTEGER);",
23         NO_PARAMS,
24     )
25     .context("In create_or_get_version: Failed to create version table.")?;
26 
27     let version = tx
28         .query_row("SELECT version FROM persistent.version WHERE id = 0;", NO_PARAMS, |row| {
29             row.get(0)
30         })
31         .optional()
32         .context("In create_or_get_version: Failed to read version.")?;
33 
34     let version = if let Some(version) = version {
35         version
36     } else {
37         // If no version table existed it could mean one of two things:
38         // 1) This database is completely new. In this case the version has to be set
39         //    to the current version and the current version which also needs to be
40         //    returned.
41         // 2) The database predates db versioning. In this case the version needs to be
42         //    set to 0, and 0 needs to be returned.
43         let version = if tx
44             .query_row(
45                 "SELECT name FROM persistent.sqlite_master
46                  WHERE type = 'table' AND name = 'keyentry';",
47                 NO_PARAMS,
48                 |_| Ok(()),
49             )
50             .optional()
51             .context("In create_or_get_version: Failed to check for keyentry table.")?
52             .is_none()
53         {
54             current_version
55         } else {
56             0
57         };
58 
59         tx.execute("INSERT INTO persistent.version (id, version) VALUES(0, ?);", params![version])
60             .context("In create_or_get_version: Failed to insert initial version.")?;
61         version
62     };
63     Ok(version)
64 }
65 
update_version(tx: &Transaction, new_version: u32) -> Result<()>66 pub fn update_version(tx: &Transaction, new_version: u32) -> Result<()> {
67     let updated = tx
68         .execute("UPDATE persistent.version SET version = ? WHERE id = 0;", params![new_version])
69         .context("In update_version: Failed to update row.")?;
70     if updated == 1 {
71         Ok(())
72     } else {
73         Err(anyhow!("In update_version: No rows were updated."))
74     }
75 }
76 
upgrade_database<F>(tx: &Transaction, current_version: u32, upgraders: &[F]) -> Result<()> where F: Fn(&Transaction) -> Result<u32> + 'static,77 pub fn upgrade_database<F>(tx: &Transaction, current_version: u32, upgraders: &[F]) -> Result<()>
78 where
79     F: Fn(&Transaction) -> Result<u32> + 'static,
80 {
81     if upgraders.len() < current_version as usize {
82         return Err(anyhow!("In upgrade_database: Insufficient upgraders provided."));
83     }
84     let mut db_version = create_or_get_version(tx, current_version)
85         .context("In upgrade_database: Failed to get database version.")?;
86     while db_version < current_version {
87         db_version = upgraders[db_version as usize](tx).with_context(|| {
88             format!("In upgrade_database: Trying to upgrade from db version {}.", db_version)
89         })?;
90     }
91     update_version(tx, db_version).context("In upgrade_database.")
92 }
93 
#[cfg(test)]
mod test {
    use super::*;
    use rusqlite::{Connection, TransactionBehavior, NO_PARAMS};

    /// Opens an in-memory connection with a second in-memory database attached as
    /// `persistent`, the schema name used by the functions under test.
    fn open_test_connection() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute("ATTACH DATABASE 'file::memory:' as persistent;", NO_PARAMS).unwrap();
        conn
    }

    /// Creates the `keyentry` table. While no version row exists, its presence makes
    /// `create_or_get_version` treat the database as legacy (version 0).
    fn create_legacy_keyentry_table(conn: &Connection) {
        conn.execute(
            "CREATE TABLE IF NOT EXISTS persistent.keyentry (
                id INTEGER UNIQUE,
                key_type INTEGER,
                domain INTEGER,
                namespace INTEGER,
                alias BLOB,
                state INTEGER,
                km_uuid BLOB);",
            NO_PARAMS,
        )
        .unwrap();
    }

    /// Runs `create_or_get_version` in its own committed immediate transaction.
    fn committed_create_or_get_version(conn: &mut Connection, current_version: u32) -> u32 {
        let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
        let version = create_or_get_version(&tx, current_version).unwrap();
        tx.commit().unwrap();
        version
    }

    /// Runs `update_version` in its own committed immediate transaction.
    fn committed_update_version(conn: &mut Connection, new_version: u32) {
        let tx = conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
        update_version(&tx, new_version).unwrap();
        tx.commit().unwrap();
    }

    /// Returns the name of the `persistent.version` table, or an error if it does not exist.
    fn version_table_name(conn: &Connection) -> rusqlite::Result<String> {
        conn.query_row(
            "SELECT name FROM persistent.sqlite_master
             WHERE type = 'table' AND name = 'version';",
            NO_PARAMS,
            |row| row.get(0),
        )
    }

    /// Returns the number of rows in `persistent.version`.
    fn version_row_count(conn: &Connection) -> rusqlite::Result<u32> {
        conn.query_row("SELECT COUNT(id) from persistent.version;", NO_PARAMS, |row| row.get(0))
    }

    /// Returns the version stored in the row with `id = 0`.
    fn stored_version(conn: &Connection) -> rusqlite::Result<u32> {
        conn.query_row(
            "SELECT version from persistent.version WHERE id = 0;",
            NO_PARAMS,
            |row| row.get(0),
        )
    }

    /// Checks the behavior shared by new and legacy databases after the first
    /// `create_or_get_version` call has recorded `initial_version`.
    fn check_version_lifecycle(conn: &mut Connection, initial_version: u32) {
        // The version table was created and holds exactly one row with the initial version.
        assert_eq!(Ok("version".to_owned()), version_table_name(conn));
        assert_eq!(Ok(1), version_row_count(conn));
        assert_eq!(Ok(initial_version), stored_version(conn));

        // Subsequent calls return the stored version even if the current version changes,
        // and do not add rows.
        assert_eq!(committed_create_or_get_version(conn, 5), initial_version);
        assert_eq!(Ok(1), version_row_count(conn));

        // After bumping the version, the new version is reported and stored in the single row.
        committed_update_version(conn, 5);
        assert_eq!(committed_create_or_get_version(conn, 7), 5);
        assert_eq!(Ok(1), version_row_count(conn));
        assert_eq!(Ok(5), stored_version(conn));
    }

    #[test]
    fn upgrade_database_test() {
        let mut conn = open_test_connection();

        // Upgrader `i` records its target version `i + 1` in `persistent.test` and reports that
        // version. The table therefore logs which upgraders ran, and in which order.
        let upgraders: Vec<_> = (0..30_u32)
            .map(move |i| {
                move |tx: &Transaction| {
                    tx.execute(
                        "INSERT INTO persistent.test (test_field) VALUES(?);",
                        params![i + 1],
                    )
                    .with_context(|| format!("In upgrade_from_{}_to_{}.", i, i + 1))?;
                    Ok(i + 1)
                }
            })
            .collect();

        for &legacy in &[false, true] {
            if legacy {
                create_legacy_keyentry_table(&conn);
            }
            for from in 0..30 {
                for to in from..30 {
                    conn.execute("DROP TABLE IF EXISTS persistent.version;", NO_PARAMS).unwrap();
                    conn.execute("DROP TABLE IF EXISTS persistent.test;", NO_PARAMS).unwrap();
                    conn.execute(
                        "CREATE TABLE IF NOT EXISTS persistent.test (
                            id INTEGER PRIMARY KEY,
                            test_field INTEGER);",
                        NO_PARAMS,
                    )
                    .unwrap();

                    // Record the starting version. For a legacy database this stores 0
                    // regardless of `from`, because the keyentry table exists.
                    committed_create_or_get_version(&mut conn, from);
                    {
                        let tx =
                            conn.transaction_with_behavior(TransactionBehavior::Immediate).unwrap();
                        upgrade_database(&tx, to, &upgraders).unwrap();
                        tx.commit().unwrap();
                    }

                    // In the legacy database case all upgraders starting from 0 have to run. So
                    // after the upgrade step, the expectations need to be adjusted.
                    let from = if legacy { 0 } else { from };

                    // There must be exactly to - from rows.
                    assert_eq!(
                        to - from,
                        conn.query_row(
                            "SELECT COUNT(test_field) FROM persistent.test;",
                            NO_PARAMS,
                            |row| row.get(0)
                        )
                        .unwrap()
                    );
                    // Each row must have the correct relation between id and test_field. If this
                    // is not the case, the upgraders were not executed in the correct order.
                    assert_eq!(
                        to - from,
                        conn.query_row(
                            "SELECT COUNT(test_field) FROM persistent.test
                             WHERE id = test_field - ?;",
                            params![from],
                            |row| row.get(0)
                        )
                        .unwrap()
                    );
                }
            }
        }
    }

    #[test]
    fn create_or_get_version_new_database() {
        let mut conn = open_test_connection();
        // Without a keyentry table the database counts as new and starts at the current version.
        assert_eq!(committed_create_or_get_version(&mut conn, 3), 3);
        check_version_lifecycle(&mut conn, 3);
    }

    #[test]
    fn create_or_get_version_legacy_database() {
        let mut conn = open_test_connection();
        // A legacy (version 0) database is detected if the keyentry table exists but no
        // version table.
        create_legacy_keyentry_table(&conn);
        // In the legacy case, version 0 must be returned.
        assert_eq!(committed_create_or_get_version(&mut conn, 3), 0);
        check_version_lifecycle(&mut conn, 0);
    }
}
380