1 # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
2 #
3 # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
4 #
5 # This file is part of pysqlite.
6 #
7 # This software is provided 'as-is', without any express or implied
8 # warranty. In no event will the authors be held liable for any damages
9 # arising from the use of this software.
10 #
11 # Permission is granted to anyone to use this software for any purpose,
12 # including commercial applications, and to alter it and redistribute it
13 # freely, subject to the following restrictions:
14 #
15 # 1. The origin of this software must not be misrepresented; you must not
16 # claim that you wrote the original software. If you use this software
17 # in a product, an acknowledgment in the product documentation would be
18 # appreciated but is not required.
19 # 2. Altered source versions must be plainly marked as such, and must not be
20 # misrepresented as being the original software.
21 # 3. This notice may not be removed or altered from any source distribution.
22
23 import contextlib
24 import sqlite3 as sqlite
25 import unittest
26
27 from test.support.os_helper import TESTFN, unlink
28
29 from test.test_sqlite3.test_dbapi import memory_database, cx_limit
30 from test.test_sqlite3.test_userfunctions import with_tracebacks
31
32
33 class ESC[4;38;5;81mCollationTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
34 def test_create_collation_not_string(self):
35 con = sqlite.connect(":memory:")
36 with self.assertRaises(TypeError):
37 con.create_collation(None, lambda x, y: (x > y) - (x < y))
38
39 def test_create_collation_not_callable(self):
40 con = sqlite.connect(":memory:")
41 with self.assertRaises(TypeError) as cm:
42 con.create_collation("X", 42)
43 self.assertEqual(str(cm.exception), 'parameter must be callable')
44
45 def test_create_collation_not_ascii(self):
46 con = sqlite.connect(":memory:")
47 con.create_collation("collä", lambda x, y: (x > y) - (x < y))
48
49 def test_create_collation_bad_upper(self):
50 class ESC[4;38;5;81mBadUpperStr(ESC[4;38;5;149mstr):
51 def upper(self):
52 return None
53 con = sqlite.connect(":memory:")
54 mycoll = lambda x, y: -((x > y) - (x < y))
55 con.create_collation(BadUpperStr("mycoll"), mycoll)
56 result = con.execute("""
57 select x from (
58 select 'a' as x
59 union
60 select 'b' as x
61 ) order by x collate mycoll
62 """).fetchall()
63 self.assertEqual(result[0][0], 'b')
64 self.assertEqual(result[1][0], 'a')
65
66 def test_collation_is_used(self):
67 def mycoll(x, y):
68 # reverse order
69 return -((x > y) - (x < y))
70
71 con = sqlite.connect(":memory:")
72 con.create_collation("mycoll", mycoll)
73 sql = """
74 select x from (
75 select 'a' as x
76 union
77 select 'b' as x
78 union
79 select 'c' as x
80 ) order by x collate mycoll
81 """
82 result = con.execute(sql).fetchall()
83 self.assertEqual(result, [('c',), ('b',), ('a',)],
84 msg='the expected order was not returned')
85
86 con.create_collation("mycoll", None)
87 with self.assertRaises(sqlite.OperationalError) as cm:
88 result = con.execute(sql).fetchall()
89 self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
90
91 def test_collation_returns_large_integer(self):
92 def mycoll(x, y):
93 # reverse order
94 return -((x > y) - (x < y)) * 2**32
95 con = sqlite.connect(":memory:")
96 con.create_collation("mycoll", mycoll)
97 sql = """
98 select x from (
99 select 'a' as x
100 union
101 select 'b' as x
102 union
103 select 'c' as x
104 ) order by x collate mycoll
105 """
106 result = con.execute(sql).fetchall()
107 self.assertEqual(result, [('c',), ('b',), ('a',)],
108 msg="the expected order was not returned")
109
110 def test_collation_register_twice(self):
111 """
112 Register two different collation functions under the same name.
113 Verify that the last one is actually used.
114 """
115 con = sqlite.connect(":memory:")
116 con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
117 con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
118 result = con.execute("""
119 select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
120 """).fetchall()
121 self.assertEqual(result[0][0], 'b')
122 self.assertEqual(result[1][0], 'a')
123
124 def test_deregister_collation(self):
125 """
126 Register a collation, then deregister it. Make sure an error is raised if we try
127 to use it.
128 """
129 con = sqlite.connect(":memory:")
130 con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
131 con.create_collation("mycoll", None)
132 with self.assertRaises(sqlite.OperationalError) as cm:
133 con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
134 self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
135
class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def _connect(self):
        """Return a new in-memory connection that is closed at test cleanup.

        A new connection is opened on every call (rather than once in setUp)
        so a test body that is executed more than once gets fresh state on
        each run.
        """
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        return con

    def test_progress_handler_used(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = self._connect()
        progress_calls = []

        def progress():
            progress_calls.append(None)
            return 0

        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def test_opcode_count(self):
        """
        Test that the opcode argument is respected.
        """
        con = self._connect()
        progress_calls = []

        def progress():
            progress_calls.append(None)
            return 0

        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Reset the counter in place; the handler closes over this list.
        progress_calls.clear()
        # Calling the handler every 2nd opcode must not produce more calls
        # than calling it on every opcode.
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        self.assertGreaterEqual(first_count, second_count)

    def test_cancel_operation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = self._connect()

        def progress():
            return 1

        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        self.assertRaises(
            sqlite.OperationalError,
            curs.execute,
            "create table bar (a, b)")

    def test_clear_handler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = self._connect()
        action = 0

        def progress():
            nonlocal action
            action = 1
            return 0

        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler(self):
        con = self._connect()

        def bad_progress():
            1 / 0

        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)

    @with_tracebacks(ZeroDivisionError, name="bad_progress")
    def test_error_in_progress_handler_result(self):
        con = self._connect()

        class BadBool:
            # Converting the handler's return value to bool raises.
            def __bool__(self):
                1 / 0

        def bad_progress():
            return BadBool()

        con.set_progress_handler(bad_progress, 1)
        with self.assertRaises(sqlite.OperationalError):
            con.execute("""
                create table foo(a, b)
                """)
228
229
class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    def _connect(self):
        """Return a new in-memory connection that is closed at test cleanup."""
        con = sqlite.connect(":memory:")
        self.addCleanup(con.close)
        return con

    @contextlib.contextmanager
    def check_stmt_trace(self, cx, expected):
        """Record statements traced on *cx* inside the with-body.

        When the body completes normally, the recorded statements are
        compared with *expected*. The trace callback is removed in every
        case, including when the body raises.
        """
        traced = []
        cx.set_trace_callback(traced.append)
        try:
            yield
        finally:
            # Unregister before asserting, so a failed comparison cannot
            # leave the callback installed on the connection.
            cx.set_trace_callback(None)
        self.assertEqual(traced, expected)

    def test_trace_callback_used(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = self._connect()
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def test_clear_trace_callback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = self._connect()
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def test_unicode_content(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = self._connect()
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # The value is embedded literally so it must appear in the trace text.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    def test_trace_callback_content(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []

        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        self.addCleanup(unlink, TESTFN)
        con1 = sqlite.connect(TESTFN, isolation_level=None)
        con2 = sqlite.connect(TESTFN)
        try:
            con1.set_trace_callback(trace)
            cur = con1.cursor()
            cur.execute(queries[0])
            # A schema change from another connection between the two traced
            # statements; only con1's statements may be reported.
            con2.execute("create table bar(x)")
            cur.execute(queries[1])
        finally:
            con1.close()
            con2.close()
        self.assertEqual(traced_statements, queries)

    def test_trace_expanded_sql(self):
        expected = [
            "create table t(t)",
            "BEGIN ",
            "insert into t values(0)",
            "insert into t values(1)",
            "insert into t values(2)",
            "COMMIT",
        ]
        with memory_database() as cx, self.check_stmt_trace(cx, expected):
            with cx:
                cx.execute("create table t(t)")
                cx.executemany("insert into t values(?)", ((v,) for v in range(3)))

    @with_tracebacks(
        sqlite.DataError,
        regex="Expanded SQL string exceeds the maximum string length"
    )
    def test_trace_too_much_expanded_sql(self):
        # If the expanded string is too large, we'll fall back to the
        # unexpanded SQL statement (for SQLite 3.14.0 and newer).
        # The resulting string length is limited by the runtime limit
        # SQLITE_LIMIT_LENGTH.
        template = "select 1 as a where a="
        category = sqlite.SQLITE_LIMIT_LENGTH
        with memory_database() as cx, cx_limit(cx, category=category) as lim:
            ok_param = "a"
            bad_param = "a" * lim

            unexpanded_query = template + "?"
            expected = [unexpanded_query]
            if sqlite.sqlite_version_info < (3, 14, 0):
                expected = []
            with self.check_stmt_trace(cx, expected):
                cx.execute(unexpanded_query, (bad_param,))

            expanded_query = f"{template}'{ok_param}'"
            with self.check_stmt_trace(cx, [expanded_query]):
                cx.execute(unexpanded_query, (ok_param,))

    @with_tracebacks(ZeroDivisionError, regex="division by zero")
    def test_trace_bad_handler(self):
        with memory_database() as cx:
            cx.set_trace_callback(lambda stmt: 5/0)
            cx.execute("select 1")
351
352
if __name__ == "__main__":
    # Executed only when this file is run directly; unittest then discovers
    # and runs the TestCase classes defined in the __main__ module.
    unittest.main(module="__main__")