1 //! Prepared statements cache for faster execution.
2 
3 use crate::raw_statement::RawStatement;
4 use crate::{Connection, Result, Statement};
5 use hashlink::LruCache;
6 use std::cell::RefCell;
7 use std::ops::{Deref, DerefMut};
8 use std::sync::Arc;
9 
10 impl Connection {
11     /// Prepare a SQL statement for execution, returning a previously prepared
12     /// (but not currently in-use) statement if one is available. The
13     /// returned statement will be cached for reuse by future calls to
14     /// `prepare_cached` once it is dropped.
15     ///
16     /// ```rust,no_run
17     /// # use rusqlite::{Connection, Result};
18     /// fn insert_new_people(conn: &Connection) -> Result<()> {
19     ///     {
20     ///         let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?;
21     ///         stmt.execute(&["Joe Smith"])?;
22     ///     }
23     ///     {
24     ///         // This will return the same underlying SQLite statement handle without
25     ///         // having to prepare it again.
26     ///         let mut stmt = conn.prepare_cached("INSERT INTO People (name) VALUES (?)")?;
27     ///         stmt.execute(&["Bob Jones"])?;
28     ///     }
29     ///     Ok(())
30     /// }
31     /// ```
32     ///
33     /// # Failure
34     ///
35     /// Will return `Err` if `sql` cannot be converted to a C-compatible string
36     /// or if the underlying SQLite call fails.
prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>>37     pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
38         self.cache.get(self, sql)
39     }
40 
41     /// Set the maximum number of cached prepared statements this connection
42     /// will hold. By default, a connection will hold a relatively small
43     /// number of cached statements. If you need more, or know that you
44     /// will not use cached statements, you
45     /// can set the capacity manually using this method.
set_prepared_statement_cache_capacity(&self, capacity: usize)46     pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
47         self.cache.set_capacity(capacity)
48     }
49 
50     /// Remove/finalize all prepared statements currently in the cache.
flush_prepared_statement_cache(&self)51     pub fn flush_prepared_statement_cache(&self) {
52         self.cache.flush()
53     }
54 }
55 
56 /// Prepared statements LRU cache.
57 // #[derive(Debug)] // FIXME: https://github.com/kyren/hashlink/pull/4
58 pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);
59 
60 /// Cacheable statement.
61 ///
62 /// Statement will return automatically to the cache by default.
63 /// If you want the statement to be discarded, call `discard()` on it.
64 pub struct CachedStatement<'conn> {
65     stmt: Option<Statement<'conn>>,
66     cache: &'conn StatementCache,
67 }
68 
69 impl<'conn> Deref for CachedStatement<'conn> {
70     type Target = Statement<'conn>;
71 
deref(&self) -> &Statement<'conn>72     fn deref(&self) -> &Statement<'conn> {
73         self.stmt.as_ref().unwrap()
74     }
75 }
76 
77 impl<'conn> DerefMut for CachedStatement<'conn> {
deref_mut(&mut self) -> &mut Statement<'conn>78     fn deref_mut(&mut self) -> &mut Statement<'conn> {
79         self.stmt.as_mut().unwrap()
80     }
81 }
82 
83 impl Drop for CachedStatement<'_> {
84     #[allow(unused_must_use)]
drop(&mut self)85     fn drop(&mut self) {
86         if let Some(stmt) = self.stmt.take() {
87             self.cache.cache_stmt(unsafe { stmt.into_raw() });
88         }
89     }
90 }
91 
92 impl CachedStatement<'_> {
new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn>93     fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
94         CachedStatement {
95             stmt: Some(stmt),
96             cache,
97         }
98     }
99 
100     /// Discard the statement, preventing it from being returned to its
101     /// `Connection`'s collection of cached statements.
discard(mut self)102     pub fn discard(mut self) {
103         self.stmt = None;
104     }
105 }
106 
107 impl StatementCache {
108     /// Create a statement cache.
with_capacity(capacity: usize) -> StatementCache109     pub fn with_capacity(capacity: usize) -> StatementCache {
110         StatementCache(RefCell::new(LruCache::new(capacity)))
111     }
112 
set_capacity(&self, capacity: usize)113     fn set_capacity(&self, capacity: usize) {
114         self.0.borrow_mut().set_capacity(capacity)
115     }
116 
117     // Search the cache for a prepared-statement object that implements `sql`.
118     // If no such prepared-statement can be found, allocate and prepare a new one.
119     //
120     // # Failure
121     //
122     // Will return `Err` if no cached statement can be found and the underlying
123     // SQLite prepare call fails.
get<'conn>( &'conn self, conn: &'conn Connection, sql: &str, ) -> Result<CachedStatement<'conn>>124     fn get<'conn>(
125         &'conn self,
126         conn: &'conn Connection,
127         sql: &str,
128     ) -> Result<CachedStatement<'conn>> {
129         let trimmed = sql.trim();
130         let mut cache = self.0.borrow_mut();
131         let stmt = match cache.remove(trimmed) {
132             Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
133             None => conn.prepare(trimmed),
134         };
135         stmt.map(|mut stmt| {
136             stmt.stmt.set_statement_cache_key(trimmed);
137             CachedStatement::new(stmt, self)
138         })
139     }
140 
141     // Return a statement to the cache.
cache_stmt(&self, stmt: RawStatement)142     fn cache_stmt(&self, stmt: RawStatement) {
143         if stmt.is_null() {
144             return;
145         }
146         let mut cache = self.0.borrow_mut();
147         stmt.clear_bindings();
148         if let Some(sql) = stmt.statement_cache_key() {
149             cache.insert(sql, stmt);
150         } else {
151             debug_assert!(
152                 false,
153                 "bug in statement cache code, statement returned to cache that without key"
154             );
155         }
156     }
157 
flush(&self)158     fn flush(&self) {
159         let mut cache = self.0.borrow_mut();
160         cache.clear()
161     }
162 }
163 
#[cfg(test)]
mod test {
    use super::StatementCache;
    use crate::{Connection, NO_PARAMS};
    use fallible_iterator::FallibleIterator;

    impl StatementCache {
        /// Number of idle statements currently held by the cache.
        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        /// Maximum number of idle statements the cache will hold.
        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    #[test]
    fn test_cache() {
        let db = Connection::open_in_memory().unwrap();
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());

        // Flushing empties the cache but leaves its capacity unchanged.
        db.flush_prepared_statement_cache();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
    }

    #[test]
    fn test_set_capacity() {
        let db = Connection::open_in_memory().unwrap();
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());

        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());
    }

    #[test]
    fn test_discard() {
        let db = Connection::open_in_memory().unwrap();
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
            stmt.discard();
        }
        assert_eq!(0, cache.len());
    }

    #[test]
    fn test_ddl() {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            r#"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
        "#,
        )
        .unwrap();

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(
                Ok(Some(1i32)),
                stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)).next()
            );
        }

        db.execute_batch(
            r#"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
        "#,
        )
        .unwrap();

        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query(NO_PARAMS)
                    .unwrap()
                    .map(|r| Ok((r.get(0)?, r.get(1)?)))
                    .next()
            );
        }
    }

    #[test]
    fn test_connection_close() {
        let conn = Connection::open_in_memory().unwrap();
        conn.prepare_cached("SELECT * FROM sqlite_master;").unwrap();

        conn.close().expect("connection not closed");
    }

    #[test]
    fn test_cache_key() {
        let db = Connection::open_in_memory().unwrap();
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        let sql = "PRAGMA schema_version; ";
        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql).unwrap();
            assert_eq!(0, cache.len());
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());
    }

    #[test]
    fn test_cache_key_ignores_surrounding_whitespace() {
        let db = Connection::open_in_memory().unwrap();
        let cache = &db.cache;

        {
            let mut stmt = db.prepare_cached("  PRAGMA schema_version\n").unwrap();
            assert_eq!(
                0,
                stmt.query_row(NO_PARAMS, |r| r.get::<_, i64>(0)).unwrap()
            );
        }
        assert_eq!(1, cache.len());

        {
            // Same SQL apart from padding: the cached entry is taken out
            // instead of a second statement being prepared alongside it.
            let _stmt = db.prepare_cached("PRAGMA schema_version").unwrap();
            assert_eq!(0, cache.len());
        }
        assert_eq!(1, cache.len());
    }

    #[test]
    fn test_empty_stmt() {
        let conn = Connection::open_in_memory().unwrap();
        conn.prepare_cached("").unwrap();
    }
}
361