# -*- coding: utf-8 -*-
# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import sqlite3 as sqlite
import sys
import unittest


def _connect(testcase):
    """Return a new in-memory connection that is closed after *testcase*.

    Registering ``close`` with ``addCleanup`` releases the connection even
    when the test fails, instead of leaving it to the garbage collector.
    """
    con = sqlite.connect(":memory:")
    testcase.addCleanup(con.close)
    return con


class CollationTests(unittest.TestCase):
    """Tests for ``Connection.create_collation``."""

    def CheckCreateCollationNotString(self):
        """A non-string collation name is rejected with TypeError."""
        con = _connect(self)
        with self.assertRaises(TypeError):
            con.create_collation(None, lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationNotCallable(self):
        """A non-callable collation function is rejected with TypeError."""
        con = _connect(self)
        with self.assertRaises(TypeError) as cm:
            con.create_collation("X", 42)
        self.assertEqual(str(cm.exception), 'parameter must be callable')

    @unittest.skipIf(sys.version_info >= (3, 11),
                     'non-ASCII collation names are accepted since Python 3.11')
    def CheckCreateCollationNotAscii(self):
        """A non-ASCII collation name is rejected with ProgrammingError."""
        con = _connect(self)
        with self.assertRaises(sqlite.ProgrammingError):
            # Escaped rather than written literally so the test does not
            # depend on the encoding this source file is stored in.
            con.create_collation("coll\xe4", lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationBadUpper(self):
        """
        A str subclass whose upper() misbehaves (returns None) can still be
        registered as a collation name and used in a query.
        """
        class BadUpperStr(str):
            def upper(self):
                return None
        con = _connect(self)
        mycoll = lambda x, y: -((x > y) - (x < y))
        con.create_collation(BadUpperStr("mycoll"), mycoll)
        result = con.execute("""
            select x from (
            select 'a' as x
            union
            select 'b' as x
            ) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    @unittest.skipIf(sys.version_info < (3,) or sqlite.sqlite_version_info < (3, 2, 1),
                     'old SQLite versions crash on this test')
    def CheckCollationIsUsed(self):
        """
        A registered collation drives ORDER BY ... COLLATE, and using it
        after it has been removed (callable=None) raises OperationalError.
        """
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        con = _connect(self)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg='the expected order was not returned')

        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute(sql).fetchall()
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')

    def CheckCollationReturnsLargeInteger(self):
        """
        A collation returning values outside the C int range is still
        interpreted by its sign.
        """
        def mycoll(x, y):
            # reverse order, scaled by 2**32
            return -((x > y) - (x < y)) * 2**32
        con = _connect(self)
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg="the expected order was not returned")

    def CheckCollationRegisterTwice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        con = _connect(self)
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
        result = con.execute("""
            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def CheckDeregisterCollation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        con = _connect(self)
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for ``Connection.set_progress_handler``."""

    def CheckProgressHandlerUsed(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = _connect(self)
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def CheckOpcodeCount(self):
        """
        Test that the opcode argument is respected.
        """
        con = _connect(self)
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Empty the list in place (rather than rebinding the name) so it is
        # obvious that ``progress`` keeps appending to the same list.
        progress_calls.clear()
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        # Calling the handler every 2 opcodes instead of every opcode must
        # not produce more calls for a comparable statement.
        self.assertGreaterEqual(first_count, second_count)

    def CheckCancelOperation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = _connect(self)
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 1
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        self.assertRaises(
            sqlite.OperationalError,
            curs.execute,
            "create table bar (a, b)")
        # The abort must have come from our handler.
        self.assertTrue(progress_calls)

    def CheckClearHandler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = _connect(self)
        action = 0
        def progress():
            nonlocal action
            action = 1
            return 0
        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")


class TraceCallbackTests(unittest.TestCase):
    """Tests for ``Connection.set_trace_callback``."""

    def CheckTraceCallbackUsed(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = _connect(self)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def CheckClearTraceCallback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = _connect(self)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def CheckUnicodeContent(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = _connect(self)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # Can't execute bound parameters as their values don't appear
        # in traced statements before SQLite 3.6.21
        # (cf. https://www.sqlite.org/releaselog/3_6_21.html)
        # The literal is single-quoted: a double-quoted token is an identifier
        # in SQL and is only accepted as a string by SQLite's legacy
        # "double-quoted string literal" quirk, which builds may disable.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))


def suite():
    """Return a TestSuite with every ``Check*`` test of this module."""
    loader = unittest.TestLoader()
    # The tests use the legacy "Check" method prefix instead of "test".
    # unittest.makeSuite() took the prefix directly but was deprecated in
    # Python 3.11 and removed in 3.13; configuring a loader is equivalent.
    loader.testMethodPrefix = "Check"
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(case)
        for case in (CollationTests, ProgressTests, TraceCallbackTests)
    )


def test():
    """Run :func:`suite` with a text runner reporting to stderr."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()