1 # pysqlite2/test/userfunctions.py: tests for user-defined functions and
2 # aggregates.
3 #
4 # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
5 #
6 # This file is part of pysqlite.
7 #
8 # This software is provided 'as-is', without any express or implied
9 # warranty. In no event will the authors be held liable for any damages
10 # arising from the use of this software.
11 #
12 # Permission is granted to anyone to use this software for any purpose,
13 # including commercial applications, and to alter it and redistribute it
14 # freely, subject to the following restrictions:
15 #
16 # 1. The origin of this software must not be misrepresented; you must not
17 # claim that you wrote the original software. If you use this software
18 # in a product, an acknowledgment in the product documentation would be
19 # appreciated but is not required.
20 # 2. Altered source versions must be plainly marked as such, and must not be
21 # misrepresented as being the original software.
22 # 3. This notice may not be removed or altered from any source distribution.
23
24 import contextlib
25 import functools
26 import io
27 import re
28 import sys
29 import unittest
30 import sqlite3 as sqlite
31
32 from unittest.mock import Mock, patch
33 from test.support import bigmemtest, catch_unraisable_exception, gc_collect
34
35 from test.test_sqlite3.test_dbapi import cx_limit
36
37
def with_tracebacks(exc, regex="", name=""):
    """Return a decorator that runs a test twice to exercise callback tracebacks.

    Pass one enables sqlite3 callback tracebacks and verifies (through
    check_tracebacks) the unraisable exception reported by the callback:
    its type must be *exc*, its message must match *regex* (if given), and
    the reporting object's ``__name__`` must equal *name* (if given).
    Pass two runs the same test again with tracebacks disabled.
    """
    def decorator(func):
        # Compile once per decorated test; an empty regex disables the check.
        pattern = re.compile(regex) if regex else None

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Pass one: tracebacks on, unraisable exception captured and checked.
            with catch_unraisable_exception() as cm:
                with check_tracebacks(self, cm, exc, pattern, name):
                    func(self, *args, **kwargs)

            # Pass two: tracebacks off.
            func(self, *args, **kwargs)

        return wrapper

    return decorator
53
54
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Enable callback tracebacks around a block, then verify what was reported.

    Args:
        self: the running test case; only its assert* methods are used.
        cm: the object yielded by catch_unraisable_exception(); its
            ``unraisable`` attribute holds the captured unraisable info,
            or None if nothing was reported.
        exc: the expected exception type of the unraisable exception.
        regex: a compiled pattern that must be found in the exception
            message, or a falsy value to skip that check.
        obj_name: the expected ``__name__`` of the reporting object, or a
            falsy value to skip that check.

    Callback tracebacks are always switched off again on exit, even if the
    block or one of the checks fails.
    """
    sqlite.enable_callback_tracebacks(True)
    try:
        buf = io.StringIO()
        # Divert anything written to stderr while the block runs, so it
        # does not clutter the test output.
        with contextlib.redirect_stderr(buf):
            yield

        unraisable = cm.unraisable
        # Fail with a clear assertion when no callback exception was
        # reported, rather than an AttributeError on None below.
        self.assertIsNotNone(
            unraisable, f"expected an unraisable {exc!r}, none was reported")
        self.assertEqual(unraisable.exc_type, exc)
        if regex:
            msg = str(unraisable.exc_value)
            self.assertIsNotNone(
                regex.search(msg), f"{regex.pattern!r} not found in {msg!r}")
        if obj_name:
            self.assertEqual(unraisable.object.__name__, obj_name)
    finally:
        sqlite.enable_callback_tracebacks(False)
72
73
def func_returntext():
    """Return a plain ASCII string."""
    return "foo"


def func_returntextwithnull():
    """Return a string with an embedded NUL character."""
    return "1\x002"


def func_returnunicode():
    """Return a str value (exercises the text return path)."""
    return "bar"


def func_returnint():
    """Return a small integer."""
    return 42


def func_returnfloat():
    """Return a float."""
    return 3.14


def func_returnnull():
    """Return None (SQL NULL) implicitly."""


def func_returnblob():
    """Return a bytes object."""
    return b"blob"


def func_returnlonglong():
    """Return 2**31, one past the largest signed 32-bit integer."""
    return 2 ** 31


def func_raiseexception():
    """Raise ZeroDivisionError."""
    return 1 / 0


def func_memoryerror():
    """Raise MemoryError."""
    raise MemoryError()


def func_overflowerror():
    """Raise OverflowError."""
    raise OverflowError()
96
class AggrNoStep:
    """Aggregate that deliberately lacks a step() method."""

    def __init__(self):
        """No state to set up."""

    def finalize(self):
        return 1


class AggrNoFinalize:
    """Aggregate that deliberately lacks a finalize() method."""

    def __init__(self):
        """No state to set up."""

    def step(self, x):
        """Ignore every value."""


class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        1 / 0

    def step(self, x):
        """Unreachable on a normally constructed instance."""

    def finalize(self):
        """Unreachable on a normally constructed instance."""


class AggrExceptionInStep:
    """Aggregate whose step() raises ZeroDivisionError."""

    def __init__(self):
        """No state to set up."""

    def step(self, x):
        1 / 0

    def finalize(self):
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        """No state to set up."""

    def step(self, x):
        """Ignore every value."""

    def finalize(self):
        1 / 0
140
class AggrCheckType:
    """Aggregate reporting whether the last value has the named Python type.

    finalize() returns None if step() never ran. Otherwise it returns 1 when
    the most recent value's exact type matched the tag, and 0 when it did not.
    """

    # Tag passed from SQL -> exact Python type the value must have.
    _TYPES = {"str": str, "int": int, "float": float, "None": type(None),
              "blob": bytes}

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = self._TYPES[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val
152
class AggrCheckTypes:
    """Aggregate counting how many values have the named Python type.

    step() takes a type tag plus any number of values. It adds to the
    running count every value whose exact type matches the tag.
    """

    # Tag passed from SQL -> exact Python type the values must have.
    _TYPES = {"str": str, "int": int, "float": float, "None": type(None),
              "blob": bytes}

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        expected = self._TYPES[whichType]
        self.val += sum(1 for item in vals if type(item) is expected)

    def finalize(self):
        return self.val
165
class AggrSum:
    """Aggregate summing every value it receives, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
175
class AggrText:
    """Aggregate concatenating every text value it receives, in order."""

    def __init__(self):
        self.txt = ""

    def step(self, txt):
        self.txt += txt

    def finalize(self):
        return self.txt
183
184
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("returnnan", 0, lambda: float("nan"))
        self.con.create_function("return_noncont_blob", 0,
                                 lambda: memoryview(b"blob")[::2])
        self.con.create_function("raiseexception", 0, func_raiseexception)
        self.con.create_function("memoryerror", 0, func_memoryerror)
        self.con.create_function("overflowerror", 0, func_overflowerror)

        self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
        self.con.create_function("isnone", 1, lambda x: x is None)
        self.con.create_function("spam", -1, lambda *x: len(x))
        self.con.execute("create table test(t text)")

    def tearDown(self):
        self.con.close()

    def test_func_error_on_create(self):
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def test_func_too_many_args(self):
        category = sqlite.SQLITE_LIMIT_FUNCTION_ARG
        msg = "too many arguments on function"
        # With the argument limit lowered to 1, a one-argument call still
        # works and a two-argument call is rejected.
        with cx_limit(self.con, category=category, limit=1):
            self.con.execute("select abs(-1)")
            with self.assertRaisesRegex(sqlite.OperationalError, msg):
                self.con.execute("select max(1, 2)")

    def test_func_ref_count(self):
        def getfunc():
            def f():
                return 1
            return f
        f = getfunc()
        # Keep an extra reference to f alive as the module global "foo".
        globals()["foo"] = f
        self.con.create_function("reftest", 0, f)
        cur = self.con.cursor()
        cur.execute("select reftest()")

    def test_func_return_text(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def test_func_return_text_with_null_char(self):
        cur = self.con.cursor()
        res = cur.execute("select returntextwithnull()").fetchone()[0]
        self.assertEqual(type(res), str)
        self.assertEqual(res, "1\x002")

    def test_func_return_unicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def test_func_return_int(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def test_func_return_float(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def test_func_return_null(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_func_return_blob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def test_func_return_long_long(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def test_func_return_nan(self):
        cur = self.con.cursor()
        cur.execute("select returnnan()")
        self.assertIsNone(cur.fetchone()[0])

    @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
    def test_func_exception(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    @with_tracebacks(MemoryError, name="func_memoryerror")
    def test_func_memory_error(self):
        cur = self.con.cursor()
        with self.assertRaises(MemoryError):
            cur.execute("select memoryerror()")
            cur.fetchone()

    @with_tracebacks(OverflowError, name="func_overflowerror")
    def test_func_overflow_error(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.DataError):
            cur.execute("select overflowerror()")
            cur.fetchone()

    def test_any_arguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_empty_blob(self):
        cur = self.con.execute("select isblob(x'')")
        self.assertTrue(cur.fetchone()[0])

    def test_nan_float(self):
        cur = self.con.execute("select isnone(?)", (float("nan"),))
        # SQLite has no concept of nan; it is converted to NULL
        self.assertTrue(cur.fetchone()[0])

    def test_too_large_int(self):
        err = "Python int too large to convert to SQLite INTEGER"
        self.assertRaisesRegex(OverflowError, err, self.con.execute,
                               "select spam(?)", (1 << 65,))

    def test_non_contiguous_blob(self):
        self.assertRaisesRegex(BufferError,
                               "underlying buffer is not C-contiguous",
                               self.con.execute, "select spam(?)",
                               (memoryview(b"blob")[::2],))

    @with_tracebacks(BufferError, regex="buffer.*contiguous")
    def test_return_non_contiguous_blob(self):
        with self.assertRaises(sqlite.OperationalError):
            cur = self.con.execute("select return_noncont_blob()")
            cur.fetchone()

    def test_param_surrogates(self):
        self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
                               self.con.execute, "select spam(?)",
                               ("\ud803\ude6d",))

    def test_func_params(self):
        results = []
        def append_result(arg):
            results.append((arg, type(arg)))
        self.con.create_function("test_params", 1, append_result)

        dataset = [
            (42, int),
            (-1, int),
            (1234567890123456789, int),
            (4611686018427387905, int),  # 63-bit int with non-zero low bits
            (3.14, float),
            (float('inf'), float),
            ("text", str),
            ("1\x002", str),
            ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
            (b"blob", bytes),
            (bytearray(range(2)), bytes),
            (memoryview(b"blob"), bytes),
            (None, type(None)),
        ]
        for val, _ in dataset:
            cur = self.con.execute("select test_params(?)", (val,))
            cur.fetchone()
        self.assertEqual(dataset, results)

    # Regarding deterministic functions:
    #
    # Between 3.8.3 and 3.15.0, deterministic functions were only used to
    # optimize inner loops, so for those versions we can only test whether
    # the sqlite machinery has factored out a call or not. From 3.15.0
    # onward, deterministic functions are permitted in the WHERE clauses of
    # partial indices, which allows testing based on syntax instead of on
    # query optimizer behaviour.
    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def test_func_non_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("nondeterministic", 0, mock, deterministic=False)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select nondeterministic() = nondeterministic()")
            self.assertEqual(mock.call_count, 2)
        else:
            with self.assertRaises(sqlite.OperationalError):
                self.con.execute("create index t on test(t) where nondeterministic() is not null")

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def test_func_deterministic(self):
        mock = Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select deterministic() = deterministic()")
            self.assertEqual(mock.call_count, 1)
        else:
            try:
                self.con.execute("create index t on test(t) where deterministic() is not null")
            except sqlite.OperationalError:
                self.fail("Unexpected failure while creating partial index")

    @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
    def test_func_deterministic_not_supported(self):
        with self.assertRaises(sqlite.NotSupportedError):
            self.con.create_function("deterministic", 0, int, deterministic=True)

    def test_func_deterministic_keyword_only(self):
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)

    def test_function_destructor_via_gc(self):
        # See bpo-44304: The destructor of the user function can
        # crash if is called without the GIL from the gc functions
        dest = sqlite.connect(':memory:')
        def md5sum(t):
            return

        dest.create_function("md5", 1, md5sum)
        # NOTE: the connection object itself is called here (not execute()).
        x = dest("create table lang (name, first_appeared)")
        del md5sum, dest

        # Put x in a self-referencing list so that only the cyclic garbage
        # collector, not plain refcounting, can free it.
        y = [x]
        y.append(y)

        del x, y
        gc_collect()

    @with_tracebacks(OverflowError)
    def test_func_return_too_large_int(self):
        cur = self.con.cursor()
        msg = "string or blob too big"
        for value in 2**63, -2**63-1, 2**64:
            self.con.create_function("largeint", 0, lambda value=value: value)
            with self.assertRaisesRegex(sqlite.DataError, msg):
                cur.execute("select largeint()")

    @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
    def test_func_return_text_with_surrogates(self):
        cur = self.con.cursor()
        self.con.create_function("pychr", 1, chr)
        for value in 0xd8ff, 0xdcff:
            with self.assertRaises(sqlite.OperationalError):
                cur.execute("select pychr(?)", (value,))

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=3, dry_run=False)
    def test_func_return_too_large_text(self, size):
        cur = self.con.cursor()
        # The loop variable is kept distinct from the bigmemtest-supplied
        # ``size`` parameter so that parameter is not silently shadowed.
        for length in 2**31-1, 2**31:
            self.con.create_function("largetext", 0,
                                     lambda length=length: "b" * length)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largetext()")

    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
    @bigmemtest(size=2**31, memuse=2, dry_run=False)
    def test_func_return_too_large_blob(self, size):
        cur = self.con.cursor()
        # See test_func_return_too_large_text for why ``length`` is used.
        for length in 2**31-1, 2**31:
            self.con.create_function("largeblob", 0,
                                     lambda length=length: b"b" * length)
            with self.assertRaises(sqlite.DataError):
                cur.execute("select largeblob()")

    def test_func_return_illegal_value(self):
        # Returning the TestCase instance itself: a type that cannot be
        # converted to an SQLite value.
        self.con.create_function("badreturn", 0, lambda: self)
        msg = "user-defined function raised exception"
        self.assertRaisesRegex(sqlite.OperationalError, msg,
                               self.con.execute, "select badreturn()")
478
479
class WindowSumInt:
    """Aggregate window function keeping a running integer sum.

    step() adds a value entering the window frame, inverse() removes a
    value leaving it, and value()/finalize() report the current total.
    """

    def __init__(self):
        self.count = 0

    def step(self, value):
        self.count = self.count + value

    def value(self):
        return self.count

    def inverse(self, value):
        self.count = self.count - value

    def finalize(self):
        return self.count
495
class BadWindow(Exception):
    """Marker exception used as a side effect of patched window-function methods."""
498
499
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
                 "Requires SQLite 3.25.0 or newer")
class WindowFunctionTests(unittest.TestCase):
    """Tests for aggregate window functions (Connection.create_window_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.cur = self.con.cursor()

        # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
        values = [
            ("a", 4),
            ("b", 5),
            ("c", 3),
            ("d", 8),
            ("e", 1),
        ]
        with self.con:
            self.con.execute("create table test(x, y)")
            self.con.executemany("insert into test values(?, ?)", values)
        # Each row's sum over itself and its immediate neighbours.
        self.expected = [
            ("a", 9),
            ("b", 12),
            ("c", 16),
            ("d", 12),
            ("e", 9),
        ]
        # %s is filled in with the name of the window function under test.
        self.query = """
            select x, %s(y) over (
                order by x rows between 1 preceding and 1 following
            ) as sum_y
            from test order by x
        """
        self.con.create_window_function("sumint", 1, WindowSumInt)

    def tearDown(self):
        # Release the in-memory database opened in setUp().
        self.con.close()

    def test_win_sum_int(self):
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), self.expected)

    def test_win_error_on_create(self):
        self.assertRaises(sqlite.ProgrammingError,
                          self.con.create_window_function,
                          "shouldfail", -100, WindowSumInt)

    @with_tracebacks(BadWindow)
    def test_win_exception_in_method(self):
        for meth in "__init__", "step", "value", "inverse":
            with self.subTest(meth=meth):
                with patch.object(WindowSumInt, meth, side_effect=BadWindow):
                    name = f"exc_{meth}"
                    self.con.create_window_function(name, 1, WindowSumInt)
                    msg = f"'{meth}' method raised error"
                    with self.assertRaisesRegex(sqlite.OperationalError, msg):
                        self.cur.execute(self.query % name)
                        self.cur.fetchall()

    @with_tracebacks(BadWindow)
    def test_win_exception_in_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
            name = "exception_in_finalize"
            self.con.create_window_function(name, 1, WindowSumInt)
            self.cur.execute(self.query % name)
            self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_method(self):
        class MissingValue:
            def step(self, x): pass
            def inverse(self, x): pass
            def finalize(self): return 42

        class MissingInverse:
            def step(self, x): pass
            def value(self): return 42
            def finalize(self): return 42

        class MissingStep:
            def value(self): return 42
            def inverse(self, x): pass
            def finalize(self): return 42

        dataset = (
            ("step", MissingStep),
            ("value", MissingValue),
            ("inverse", MissingInverse),
        )
        for meth, cls in dataset:
            with self.subTest(meth=meth, cls=cls):
                name = f"exc_{meth}"
                self.con.create_window_function(name, 1, cls)
                with self.assertRaisesRegex(sqlite.OperationalError,
                                            f"'{meth}' method not defined"):
                    self.cur.execute(self.query % name)
                    self.cur.fetchall()

    @with_tracebacks(AttributeError)
    def test_win_missing_finalize(self):
        # Note: SQLite does not (as of version 3.38.0) propagate finalize
        # callback errors to sqlite3_step(); this implies that OperationalError
        # is _not_ raised.
        class MissingFinalize:
            def step(self, x): pass
            def value(self): return 42
            def inverse(self, x): pass

        name = "missing_finalize"
        self.con.create_window_function(name, 1, MissingFinalize)
        self.cur.execute(self.query % name)
        self.cur.fetchall()

    def test_win_clear_function(self):
        # Passing None as the aggregate class removes the function.
        self.con.create_window_function("sumint", 1, None)
        self.assertRaises(sqlite.OperationalError, self.cur.execute,
                          self.query % "sumint")

    def test_win_redefine_function(self):
        # Redefine WindowSumInt; adjust the expected results accordingly.
        class Redefined(WindowSumInt):
            def step(self, value): self.count += value * 2
            def inverse(self, value): self.count -= value * 2
        expected = [(v[0], v[1]*2) for v in self.expected]

        self.con.create_window_function("sumint", 1, Redefined)
        self.cur.execute(self.query % "sumint")
        self.assertEqual(self.cur.fetchall(), expected)

    def test_win_error_value_return(self):
        class ErrorValueReturn:
            def __init__(self): pass
            def step(self, x): pass
            def value(self): return 1 << 65

        self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
        self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
                               self.cur.execute, self.query % "err_val_ret")
636
637
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
                    ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)
        self.con.create_aggregate("aggtxt", 1, AggrText)

    def tearDown(self):
        # Release the in-memory database opened in setUp().
        self.con.close()

    def test_aggr_error_on_create(self):
        # Exercise aggregate registration (not create_function) with an
        # invalid argument count.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    @with_tracebacks(AttributeError, name="AggrNoStep")
    def test_aggr_no_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception),
                         "user-defined aggregate's 'step' method not defined")

    def test_aggr_no_finalize(self):
        cur = self.con.cursor()
        msg = "user-defined aggregate's 'finalize' method not defined"
        with self.assertRaisesRegex(sqlite.OperationalError, msg):
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
    def test_aggr_exception_in_init(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
    def test_aggr_exception_in_step(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
    def test_aggr_exception_in_finalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def test_aggr_check_param_str(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_int(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_params_int(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def test_aggr_check_param_float(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_none(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_param_blob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def test_aggr_check_aggr_sum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)

    def test_aggr_no_match(self):
        # The WHERE clause matches no rows, so the aggregate yields NULL.
        cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def test_aggr_text(self):
        cur = self.con.cursor()
        for txt in ["foo", "1\x002"]:
            with self.subTest(txt=txt):
                cur.execute("select aggtxt(?) from test", (txt,))
                val = cur.fetchone()[0]
                self.assertEqual(val, txt)
768
769
770 class ESC[4;38;5;81mAuthorizerTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
771 @staticmethod
772 def authorizer_cb(action, arg1, arg2, dbname, source):
773 if action != sqlite.SQLITE_SELECT:
774 return sqlite.SQLITE_DENY
775 if arg2 == 'c2' or arg1 == 't2':
776 return sqlite.SQLITE_DENY
777 return sqlite.SQLITE_OK
778
779 def setUp(self):
780 self.con = sqlite.connect(":memory:")
781 self.con.executescript("""
782 create table t1 (c1, c2);
783 create table t2 (c1, c2);
784 insert into t1 (c1, c2) values (1, 2);
785 insert into t2 (c1, c2) values (4, 5);
786 """)
787
788 # For our security test:
789 self.con.execute("select c2 from t2")
790
791 self.con.set_authorizer(self.authorizer_cb)
792
793 def tearDown(self):
794 pass
795
796 def test_table_access(self):
797 with self.assertRaises(sqlite.DatabaseError) as cm:
798 self.con.execute("select * from t2")
799 self.assertIn('prohibited', str(cm.exception))
800
801 def test_column_access(self):
802 with self.assertRaises(sqlite.DatabaseError) as cm:
803 self.con.execute("select c2 from t1")
804 self.assertIn('prohibited', str(cm.exception))
805
806 def test_clear_authorizer(self):
807 self.con.set_authorizer(None)
808 self.con.execute("select * from t2")
809 self.con.execute("select c2 from t1")
810
811
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Run the authorizer tests with an authorizer that raises instead of denying."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_table_access(self):
        super().test_table_access()

    @with_tracebacks(ValueError, name="authorizer_cb")
    def test_column_access(self):
        # Delegate to the matching base test; this previously called
        # test_table_access, so column access was never checked here.
        super().test_column_access()
828
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Run the authorizer tests with an authorizer that rejects with 0.0.

    0.0 is a float rather than an integer result code. The inherited tests
    still expect the protected queries to be rejected.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        allowed = (action == sqlite.SQLITE_SELECT
                   and arg2 != 'c2' and arg1 != 't2')
        return sqlite.SQLITE_OK if allowed else 0.0
837
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Run the authorizer tests with an authorizer that rejects with 2**32.

    2**32 is larger than any 32-bit integer. The inherited tests still
    expect the protected queries to be rejected.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        denied = (action != sqlite.SQLITE_SELECT
                  or arg2 == 'c2' or arg1 == 't2')
        return 2**32 if denied else sqlite.SQLITE_OK
846
847
if __name__ == "__main__":
    # Run this test module directly as a script.
    unittest.main()