(root)/
Python-3.11.7/
Lib/
test/
test_sqlite3/
test_userfunctions.py
       1  # pysqlite2/test/userfunctions.py: tests for user-defined functions and
       2  #                                  aggregates.
       3  #
       4  # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
       5  #
       6  # This file is part of pysqlite.
       7  #
       8  # This software is provided 'as-is', without any express or implied
       9  # warranty.  In no event will the authors be held liable for any damages
      10  # arising from the use of this software.
      11  #
      12  # Permission is granted to anyone to use this software for any purpose,
      13  # including commercial applications, and to alter it and redistribute it
      14  # freely, subject to the following restrictions:
      15  #
      16  # 1. The origin of this software must not be misrepresented; you must not
      17  #    claim that you wrote the original software. If you use this software
      18  #    in a product, an acknowledgment in the product documentation would be
      19  #    appreciated but is not required.
      20  # 2. Altered source versions must be plainly marked as such, and must not be
      21  #    misrepresented as being the original software.
      22  # 3. This notice may not be removed or altered from any source distribution.
      23  
      24  import contextlib
      25  import functools
      26  import io
      27  import re
      28  import sys
      29  import unittest
      30  import sqlite3 as sqlite
      31  
      32  from unittest.mock import Mock, patch
      33  from test.support import bigmemtest, catch_unraisable_exception, gc_collect
      34  
      35  from test.test_sqlite3.test_dbapi import cx_limit
      36  
      37  
      38  def with_tracebacks(exc, regex="", name=""):
      39      """Convenience decorator for testing callback tracebacks."""
      40      def decorator(func):
      41          _regex = re.compile(regex) if regex else None
      42          @functools.wraps(func)
      43          def wrapper(self, *args, **kwargs):
      44              with catch_unraisable_exception() as cm:
      45                  # First, run the test with traceback enabled.
      46                  with check_tracebacks(self, cm, exc, _regex, name):
      47                      func(self, *args, **kwargs)
      48  
      49              # Then run the test with traceback disabled.
      50              func(self, *args, **kwargs)
      51          return wrapper
      52      return decorator
      53  
      54  
      55  @contextlib.contextmanager
      56  def check_tracebacks(self, cm, exc, regex, obj_name):
      57      """Convenience context manager for testing callback tracebacks."""
      58      sqlite.enable_callback_tracebacks(True)
      59      try:
      60          buf = io.StringIO()
      61          with contextlib.redirect_stderr(buf):
      62              yield
      63  
      64          self.assertEqual(cm.unraisable.exc_type, exc)
      65          if regex:
      66              msg = str(cm.unraisable.exc_value)
      67              self.assertIsNotNone(regex.search(msg))
      68          if obj_name:
      69              self.assertEqual(cm.unraisable.object.__name__, obj_name)
      70      finally:
      71          sqlite.enable_callback_tracebacks(False)
      72  
      73  
      74  def func_returntext():
      75      return "foo"
      76  def func_returntextwithnull():
      77      return "1\x002"
      78  def func_returnunicode():
      79      return "bar"
      80  def func_returnint():
      81      return 42
      82  def func_returnfloat():
      83      return 3.14
      84  def func_returnnull():
      85      return None
      86  def func_returnblob():
      87      return b"blob"
      88  def func_returnlonglong():
      89      return 1<<31
      90  def func_raiseexception():
      91      5/0
      92  def func_memoryerror():
      93      raise MemoryError
      94  def func_overflowerror():
      95      raise OverflowError
      96  
      97  class ESC[4;38;5;81mAggrNoStep:
      98      def __init__(self):
      99          pass
     100  
     101      def finalize(self):
     102          return 1
     103  
     104  class ESC[4;38;5;81mAggrNoFinalize:
     105      def __init__(self):
     106          pass
     107  
     108      def step(self, x):
     109          pass
     110  
     111  class ESC[4;38;5;81mAggrExceptionInInit:
     112      def __init__(self):
     113          5/0
     114  
     115      def step(self, x):
     116          pass
     117  
     118      def finalize(self):
     119          pass
     120  
     121  class ESC[4;38;5;81mAggrExceptionInStep:
     122      def __init__(self):
     123          pass
     124  
     125      def step(self, x):
     126          5/0
     127  
     128      def finalize(self):
     129          return 42
     130  
     131  class ESC[4;38;5;81mAggrExceptionInFinalize:
     132      def __init__(self):
     133          pass
     134  
     135      def step(self, x):
     136          pass
     137  
     138      def finalize(self):
     139          5/0
     140  
     141  class ESC[4;38;5;81mAggrCheckType:
     142      def __init__(self):
     143          self.val = None
     144  
     145      def step(self, whichType, val):
     146          theType = {"str": str, "int": int, "float": float, "None": type(None),
     147                     "blob": bytes}
     148          self.val = int(theType[whichType] is type(val))
     149  
     150      def finalize(self):
     151          return self.val
     152  
     153  class ESC[4;38;5;81mAggrCheckTypes:
     154      def __init__(self):
     155          self.val = 0
     156  
     157      def step(self, whichType, *vals):
     158          theType = {"str": str, "int": int, "float": float, "None": type(None),
     159                     "blob": bytes}
     160          for val in vals:
     161              self.val += int(theType[whichType] is type(val))
     162  
     163      def finalize(self):
     164          return self.val
     165  
     166  class ESC[4;38;5;81mAggrSum:
     167      def __init__(self):
     168          self.val = 0.0
     169  
     170      def step(self, val):
     171          self.val += val
     172  
     173      def finalize(self):
     174          return self.val
     175  
     176  class ESC[4;38;5;81mAggrText:
     177      def __init__(self):
     178          self.txt = ""
     179      def step(self, txt):
     180          self.txt = self.txt + txt
     181      def finalize(self):
     182          return self.txt
     183  
     184  
     185  class ESC[4;38;5;81mFunctionTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     186      def setUp(self):
     187          self.con = sqlite.connect(":memory:")
     188  
     189          self.con.create_function("returntext", 0, func_returntext)
     190          self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
     191          self.con.create_function("returnunicode", 0, func_returnunicode)
     192          self.con.create_function("returnint", 0, func_returnint)
     193          self.con.create_function("returnfloat", 0, func_returnfloat)
     194          self.con.create_function("returnnull", 0, func_returnnull)
     195          self.con.create_function("returnblob", 0, func_returnblob)
     196          self.con.create_function("returnlonglong", 0, func_returnlonglong)
     197          self.con.create_function("returnnan", 0, lambda: float("nan"))
     198          self.con.create_function("return_noncont_blob", 0,
     199                                   lambda: memoryview(b"blob")[::2])
     200          self.con.create_function("raiseexception", 0, func_raiseexception)
     201          self.con.create_function("memoryerror", 0, func_memoryerror)
     202          self.con.create_function("overflowerror", 0, func_overflowerror)
     203  
     204          self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
     205          self.con.create_function("isnone", 1, lambda x: x is None)
     206          self.con.create_function("spam", -1, lambda *x: len(x))
     207          self.con.execute("create table test(t text)")
     208  
     209      def tearDown(self):
     210          self.con.close()
     211  
     212      def test_func_error_on_create(self):
     213          with self.assertRaises(sqlite.OperationalError):
     214              self.con.create_function("bla", -100, lambda x: 2*x)
     215  
     216      def test_func_too_many_args(self):
     217          category = sqlite.SQLITE_LIMIT_FUNCTION_ARG
     218          msg = "too many arguments on function"
     219          with cx_limit(self.con, category=category, limit=1):
     220              self.con.execute("select abs(-1)");
     221              with self.assertRaisesRegex(sqlite.OperationalError, msg):
     222                  self.con.execute("select max(1, 2)");
     223  
     224      def test_func_ref_count(self):
     225          def getfunc():
     226              def f():
     227                  return 1
     228              return f
     229          f = getfunc()
     230          globals()["foo"] = f
     231          # self.con.create_function("reftest", 0, getfunc())
     232          self.con.create_function("reftest", 0, f)
     233          cur = self.con.cursor()
     234          cur.execute("select reftest()")
     235  
     236      def test_func_return_text(self):
     237          cur = self.con.cursor()
     238          cur.execute("select returntext()")
     239          val = cur.fetchone()[0]
     240          self.assertEqual(type(val), str)
     241          self.assertEqual(val, "foo")
     242  
     243      def test_func_return_text_with_null_char(self):
     244          cur = self.con.cursor()
     245          res = cur.execute("select returntextwithnull()").fetchone()[0]
     246          self.assertEqual(type(res), str)
     247          self.assertEqual(res, "1\x002")
     248  
     249      def test_func_return_unicode(self):
     250          cur = self.con.cursor()
     251          cur.execute("select returnunicode()")
     252          val = cur.fetchone()[0]
     253          self.assertEqual(type(val), str)
     254          self.assertEqual(val, "bar")
     255  
     256      def test_func_return_int(self):
     257          cur = self.con.cursor()
     258          cur.execute("select returnint()")
     259          val = cur.fetchone()[0]
     260          self.assertEqual(type(val), int)
     261          self.assertEqual(val, 42)
     262  
     263      def test_func_return_float(self):
     264          cur = self.con.cursor()
     265          cur.execute("select returnfloat()")
     266          val = cur.fetchone()[0]
     267          self.assertEqual(type(val), float)
     268          if val < 3.139 or val > 3.141:
     269              self.fail("wrong value")
     270  
     271      def test_func_return_null(self):
     272          cur = self.con.cursor()
     273          cur.execute("select returnnull()")
     274          val = cur.fetchone()[0]
     275          self.assertEqual(type(val), type(None))
     276          self.assertEqual(val, None)
     277  
     278      def test_func_return_blob(self):
     279          cur = self.con.cursor()
     280          cur.execute("select returnblob()")
     281          val = cur.fetchone()[0]
     282          self.assertEqual(type(val), bytes)
     283          self.assertEqual(val, b"blob")
     284  
     285      def test_func_return_long_long(self):
     286          cur = self.con.cursor()
     287          cur.execute("select returnlonglong()")
     288          val = cur.fetchone()[0]
     289          self.assertEqual(val, 1<<31)
     290  
     291      def test_func_return_nan(self):
     292          cur = self.con.cursor()
     293          cur.execute("select returnnan()")
     294          self.assertIsNone(cur.fetchone()[0])
     295  
     296      @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
     297      def test_func_exception(self):
     298          cur = self.con.cursor()
     299          with self.assertRaises(sqlite.OperationalError) as cm:
     300              cur.execute("select raiseexception()")
     301              cur.fetchone()
     302          self.assertEqual(str(cm.exception), 'user-defined function raised exception')
     303  
     304      @with_tracebacks(MemoryError, name="func_memoryerror")
     305      def test_func_memory_error(self):
     306          cur = self.con.cursor()
     307          with self.assertRaises(MemoryError):
     308              cur.execute("select memoryerror()")
     309              cur.fetchone()
     310  
     311      @with_tracebacks(OverflowError, name="func_overflowerror")
     312      def test_func_overflow_error(self):
     313          cur = self.con.cursor()
     314          with self.assertRaises(sqlite.DataError):
     315              cur.execute("select overflowerror()")
     316              cur.fetchone()
     317  
     318      def test_any_arguments(self):
     319          cur = self.con.cursor()
     320          cur.execute("select spam(?, ?)", (1, 2))
     321          val = cur.fetchone()[0]
     322          self.assertEqual(val, 2)
     323  
     324      def test_empty_blob(self):
     325          cur = self.con.execute("select isblob(x'')")
     326          self.assertTrue(cur.fetchone()[0])
     327  
     328      def test_nan_float(self):
     329          cur = self.con.execute("select isnone(?)", (float("nan"),))
     330          # SQLite has no concept of nan; it is converted to NULL
     331          self.assertTrue(cur.fetchone()[0])
     332  
     333      def test_too_large_int(self):
     334          err = "Python int too large to convert to SQLite INTEGER"
     335          self.assertRaisesRegex(OverflowError, err, self.con.execute,
     336                                 "select spam(?)", (1 << 65,))
     337  
     338      def test_non_contiguous_blob(self):
     339          self.assertRaisesRegex(BufferError,
     340                                 "underlying buffer is not C-contiguous",
     341                                 self.con.execute, "select spam(?)",
     342                                 (memoryview(b"blob")[::2],))
     343  
     344      @with_tracebacks(BufferError, regex="buffer.*contiguous")
     345      def test_return_non_contiguous_blob(self):
     346          with self.assertRaises(sqlite.OperationalError):
     347              cur = self.con.execute("select return_noncont_blob()")
     348              cur.fetchone()
     349  
     350      def test_param_surrogates(self):
     351          self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
     352                                 self.con.execute, "select spam(?)",
     353                                 ("\ud803\ude6d",))
     354  
     355      def test_func_params(self):
     356          results = []
     357          def append_result(arg):
     358              results.append((arg, type(arg)))
     359          self.con.create_function("test_params", 1, append_result)
     360  
     361          dataset = [
     362              (42, int),
     363              (-1, int),
     364              (1234567890123456789, int),
     365              (4611686018427387905, int),  # 63-bit int with non-zero low bits
     366              (3.14, float),
     367              (float('inf'), float),
     368              ("text", str),
     369              ("1\x002", str),
     370              ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
     371              (b"blob", bytes),
     372              (bytearray(range(2)), bytes),
     373              (memoryview(b"blob"), bytes),
     374              (None, type(None)),
     375          ]
     376          for val, _ in dataset:
     377              cur = self.con.execute("select test_params(?)", (val,))
     378              cur.fetchone()
     379          self.assertEqual(dataset, results)
     380  
     381      # Regarding deterministic functions:
     382      #
     383      # Between 3.8.3 and 3.15.0, deterministic functions were only used to
     384      # optimize inner loops, so for those versions we can only test if the
     385      # sqlite machinery has factored out a call or not. From 3.15.0 and onward,
     386      # deterministic functions were permitted in WHERE clauses of partial
     387      # indices, which allows testing based on syntax, iso. the query optimizer.
     388      @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
     389      def test_func_non_deterministic(self):
     390          mock = Mock(return_value=None)
     391          self.con.create_function("nondeterministic", 0, mock, deterministic=False)
     392          if sqlite.sqlite_version_info < (3, 15, 0):
     393              self.con.execute("select nondeterministic() = nondeterministic()")
     394              self.assertEqual(mock.call_count, 2)
     395          else:
     396              with self.assertRaises(sqlite.OperationalError):
     397                  self.con.execute("create index t on test(t) where nondeterministic() is not null")
     398  
     399      @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
     400      def test_func_deterministic(self):
     401          mock = Mock(return_value=None)
     402          self.con.create_function("deterministic", 0, mock, deterministic=True)
     403          if sqlite.sqlite_version_info < (3, 15, 0):
     404              self.con.execute("select deterministic() = deterministic()")
     405              self.assertEqual(mock.call_count, 1)
     406          else:
     407              try:
     408                  self.con.execute("create index t on test(t) where deterministic() is not null")
     409              except sqlite.OperationalError:
     410                  self.fail("Unexpected failure while creating partial index")
     411  
     412      @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
     413      def test_func_deterministic_not_supported(self):
     414          with self.assertRaises(sqlite.NotSupportedError):
     415              self.con.create_function("deterministic", 0, int, deterministic=True)
     416  
     417      def test_func_deterministic_keyword_only(self):
     418          with self.assertRaises(TypeError):
     419              self.con.create_function("deterministic", 0, int, True)
     420  
     421      def test_function_destructor_via_gc(self):
     422          # See bpo-44304: The destructor of the user function can
     423          # crash if is called without the GIL from the gc functions
     424          dest = sqlite.connect(':memory:')
     425          def md5sum(t):
     426              return
     427  
     428          dest.create_function("md5", 1, md5sum)
     429          x = dest("create table lang (name, first_appeared)")
     430          del md5sum, dest
     431  
     432          y = [x]
     433          y.append(y)
     434  
     435          del x,y
     436          gc_collect()
     437  
     438      @with_tracebacks(OverflowError)
     439      def test_func_return_too_large_int(self):
     440          cur = self.con.cursor()
     441          msg = "string or blob too big"
     442          for value in 2**63, -2**63-1, 2**64:
     443              self.con.create_function("largeint", 0, lambda value=value: value)
     444              with self.assertRaisesRegex(sqlite.DataError, msg):
     445                  cur.execute("select largeint()")
     446  
     447      @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
     448      def test_func_return_text_with_surrogates(self):
     449          cur = self.con.cursor()
     450          self.con.create_function("pychr", 1, chr)
     451          for value in 0xd8ff, 0xdcff:
     452              with self.assertRaises(sqlite.OperationalError):
     453                  cur.execute("select pychr(?)", (value,))
     454  
     455      @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
     456      @bigmemtest(size=2**31, memuse=3, dry_run=False)
     457      def test_func_return_too_large_text(self, size):
     458          cur = self.con.cursor()
     459          for size in 2**31-1, 2**31:
     460              self.con.create_function("largetext", 0, lambda size=size: "b" * size)
     461              with self.assertRaises(sqlite.DataError):
     462                  cur.execute("select largetext()")
     463  
     464      @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
     465      @bigmemtest(size=2**31, memuse=2, dry_run=False)
     466      def test_func_return_too_large_blob(self, size):
     467          cur = self.con.cursor()
     468          for size in 2**31-1, 2**31:
     469              self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
     470              with self.assertRaises(sqlite.DataError):
     471                  cur.execute("select largeblob()")
     472  
     473      def test_func_return_illegal_value(self):
     474          self.con.create_function("badreturn", 0, lambda: self)
     475          msg = "user-defined function raised exception"
     476          self.assertRaisesRegex(sqlite.OperationalError, msg,
     477                                 self.con.execute, "select badreturn()")
     478  
     479  
     480  class ESC[4;38;5;81mWindowSumInt:
     481      def __init__(self):
     482          self.count = 0
     483  
     484      def step(self, value):
     485          self.count += value
     486  
     487      def value(self):
     488          return self.count
     489  
     490      def inverse(self, value):
     491          self.count -= value
     492  
     493      def finalize(self):
     494          return self.count
     495  
     496  class ESC[4;38;5;81mBadWindow(ESC[4;38;5;149mException):
     497      pass
     498  
     499  
     500  @unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
     501                   "Requires SQLite 3.25.0 or newer")
     502  class ESC[4;38;5;81mWindowFunctionTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     503      def setUp(self):
     504          self.con = sqlite.connect(":memory:")
     505          self.cur = self.con.cursor()
     506  
     507          # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
     508          values = [
     509              ("a", 4),
     510              ("b", 5),
     511              ("c", 3),
     512              ("d", 8),
     513              ("e", 1),
     514          ]
     515          with self.con:
     516              self.con.execute("create table test(x, y)")
     517              self.con.executemany("insert into test values(?, ?)", values)
     518          self.expected = [
     519              ("a", 9),
     520              ("b", 12),
     521              ("c", 16),
     522              ("d", 12),
     523              ("e", 9),
     524          ]
     525          self.query = """
     526              select x, %s(y) over (
     527                  order by x rows between 1 preceding and 1 following
     528              ) as sum_y
     529              from test order by x
     530          """
     531          self.con.create_window_function("sumint", 1, WindowSumInt)
     532  
     533      def test_win_sum_int(self):
     534          self.cur.execute(self.query % "sumint")
     535          self.assertEqual(self.cur.fetchall(), self.expected)
     536  
     537      def test_win_error_on_create(self):
     538          self.assertRaises(sqlite.ProgrammingError,
     539                            self.con.create_window_function,
     540                            "shouldfail", -100, WindowSumInt)
     541  
     542      @with_tracebacks(BadWindow)
     543      def test_win_exception_in_method(self):
     544          for meth in "__init__", "step", "value", "inverse":
     545              with self.subTest(meth=meth):
     546                  with patch.object(WindowSumInt, meth, side_effect=BadWindow):
     547                      name = f"exc_{meth}"
     548                      self.con.create_window_function(name, 1, WindowSumInt)
     549                      msg = f"'{meth}' method raised error"
     550                      with self.assertRaisesRegex(sqlite.OperationalError, msg):
     551                          self.cur.execute(self.query % name)
     552                          self.cur.fetchall()
     553  
     554      @with_tracebacks(BadWindow)
     555      def test_win_exception_in_finalize(self):
     556          # Note: SQLite does not (as of version 3.38.0) propagate finalize
     557          # callback errors to sqlite3_step(); this implies that OperationalError
     558          # is _not_ raised.
     559          with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
     560              name = f"exception_in_finalize"
     561              self.con.create_window_function(name, 1, WindowSumInt)
     562              self.cur.execute(self.query % name)
     563              self.cur.fetchall()
     564  
     565      @with_tracebacks(AttributeError)
     566      def test_win_missing_method(self):
     567          class ESC[4;38;5;81mMissingValue:
     568              def step(self, x): pass
     569              def inverse(self, x): pass
     570              def finalize(self): return 42
     571  
     572          class ESC[4;38;5;81mMissingInverse:
     573              def step(self, x): pass
     574              def value(self): return 42
     575              def finalize(self): return 42
     576  
     577          class ESC[4;38;5;81mMissingStep:
     578              def value(self): return 42
     579              def inverse(self, x): pass
     580              def finalize(self): return 42
     581  
     582          dataset = (
     583              ("step", MissingStep),
     584              ("value", MissingValue),
     585              ("inverse", MissingInverse),
     586          )
     587          for meth, cls in dataset:
     588              with self.subTest(meth=meth, cls=cls):
     589                  name = f"exc_{meth}"
     590                  self.con.create_window_function(name, 1, cls)
     591                  with self.assertRaisesRegex(sqlite.OperationalError,
     592                                              f"'{meth}' method not defined"):
     593                      self.cur.execute(self.query % name)
     594                      self.cur.fetchall()
     595  
     596      @with_tracebacks(AttributeError)
     597      def test_win_missing_finalize(self):
     598          # Note: SQLite does not (as of version 3.38.0) propagate finalize
     599          # callback errors to sqlite3_step(); this implies that OperationalError
     600          # is _not_ raised.
     601          class ESC[4;38;5;81mMissingFinalize:
     602              def step(self, x): pass
     603              def value(self): return 42
     604              def inverse(self, x): pass
     605  
     606          name = "missing_finalize"
     607          self.con.create_window_function(name, 1, MissingFinalize)
     608          self.cur.execute(self.query % name)
     609          self.cur.fetchall()
     610  
     611      def test_win_clear_function(self):
     612          self.con.create_window_function("sumint", 1, None)
     613          self.assertRaises(sqlite.OperationalError, self.cur.execute,
     614                            self.query % "sumint")
     615  
     616      def test_win_redefine_function(self):
     617          # Redefine WindowSumInt; adjust the expected results accordingly.
     618          class ESC[4;38;5;81mRedefined(ESC[4;38;5;149mWindowSumInt):
     619              def step(self, value): self.count += value * 2
     620              def inverse(self, value): self.count -= value * 2
     621          expected = [(v[0], v[1]*2) for v in self.expected]
     622  
     623          self.con.create_window_function("sumint", 1, Redefined)
     624          self.cur.execute(self.query % "sumint")
     625          self.assertEqual(self.cur.fetchall(), expected)
     626  
     627      def test_win_error_value_return(self):
     628          class ESC[4;38;5;81mErrorValueReturn:
     629              def __init__(self): pass
     630              def step(self, x): pass
     631              def value(self): return 1 << 65
     632  
     633          self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
     634          self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
     635                                 self.cur.execute, self.query % "err_val_ret")
     636  
     637  
     638  class ESC[4;38;5;81mAggregateTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     639      def setUp(self):
     640          self.con = sqlite.connect(":memory:")
     641          cur = self.con.cursor()
     642          cur.execute("""
     643              create table test(
     644                  t text,
     645                  i integer,
     646                  f float,
     647                  n,
     648                  b blob
     649                  )
     650              """)
     651          cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
     652              ("foo", 5, 3.14, None, memoryview(b"blob"),))
     653  
     654          self.con.create_aggregate("nostep", 1, AggrNoStep)
     655          self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
     656          self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
     657          self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
     658          self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
     659          self.con.create_aggregate("checkType", 2, AggrCheckType)
     660          self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
     661          self.con.create_aggregate("mysum", 1, AggrSum)
     662          self.con.create_aggregate("aggtxt", 1, AggrText)
     663  
     664      def tearDown(self):
     665          #self.cur.close()
     666          #self.con.close()
     667          pass
     668  
     669      def test_aggr_error_on_create(self):
     670          with self.assertRaises(sqlite.OperationalError):
     671              self.con.create_function("bla", -100, AggrSum)
     672  
     673      @with_tracebacks(AttributeError, name="AggrNoStep")
     674      def test_aggr_no_step(self):
     675          cur = self.con.cursor()
     676          with self.assertRaises(sqlite.OperationalError) as cm:
     677              cur.execute("select nostep(t) from test")
     678          self.assertEqual(str(cm.exception),
     679                           "user-defined aggregate's 'step' method not defined")
     680  
     681      def test_aggr_no_finalize(self):
     682          cur = self.con.cursor()
     683          msg = "user-defined aggregate's 'finalize' method not defined"
     684          with self.assertRaisesRegex(sqlite.OperationalError, msg):
     685              cur.execute("select nofinalize(t) from test")
     686              val = cur.fetchone()[0]
     687  
     688      @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
     689      def test_aggr_exception_in_init(self):
     690          cur = self.con.cursor()
     691          with self.assertRaises(sqlite.OperationalError) as cm:
     692              cur.execute("select excInit(t) from test")
     693              val = cur.fetchone()[0]
     694          self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
     695  
     696      @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
     697      def test_aggr_exception_in_step(self):
     698          cur = self.con.cursor()
     699          with self.assertRaises(sqlite.OperationalError) as cm:
     700              cur.execute("select excStep(t) from test")
     701              val = cur.fetchone()[0]
     702          self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
     703  
     704      @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
     705      def test_aggr_exception_in_finalize(self):
     706          cur = self.con.cursor()
     707          with self.assertRaises(sqlite.OperationalError) as cm:
     708              cur.execute("select excFinalize(t) from test")
     709              val = cur.fetchone()[0]
     710          self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
     711  
     712      def test_aggr_check_param_str(self):
     713          cur = self.con.cursor()
     714          cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
     715          val = cur.fetchone()[0]
     716          self.assertEqual(val, 2)
     717  
     718      def test_aggr_check_param_int(self):
     719          cur = self.con.cursor()
     720          cur.execute("select checkType('int', ?)", (42,))
     721          val = cur.fetchone()[0]
     722          self.assertEqual(val, 1)
     723  
     724      def test_aggr_check_params_int(self):
     725          cur = self.con.cursor()
     726          cur.execute("select checkTypes('int', ?, ?)", (42, 24))
     727          val = cur.fetchone()[0]
     728          self.assertEqual(val, 2)
     729  
     730      def test_aggr_check_param_float(self):
     731          cur = self.con.cursor()
     732          cur.execute("select checkType('float', ?)", (3.14,))
     733          val = cur.fetchone()[0]
     734          self.assertEqual(val, 1)
     735  
     736      def test_aggr_check_param_none(self):
     737          cur = self.con.cursor()
     738          cur.execute("select checkType('None', ?)", (None,))
     739          val = cur.fetchone()[0]
     740          self.assertEqual(val, 1)
     741  
     742      def test_aggr_check_param_blob(self):
     743          cur = self.con.cursor()
     744          cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
     745          val = cur.fetchone()[0]
     746          self.assertEqual(val, 1)
     747  
     748      def test_aggr_check_aggr_sum(self):
     749          cur = self.con.cursor()
     750          cur.execute("delete from test")
     751          cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
     752          cur.execute("select mysum(i) from test")
     753          val = cur.fetchone()[0]
     754          self.assertEqual(val, 60)
     755  
     756      def test_aggr_no_match(self):
     757          cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0")
     758          val = cur.fetchone()[0]
     759          self.assertIsNone(val)
     760  
     761      def test_aggr_text(self):
     762          cur = self.con.cursor()
     763          for txt in ["foo", "1\x002"]:
     764              with self.subTest(txt=txt):
     765                  cur.execute("select aggtxt(?) from test", (txt,))
     766                  val = cur.fetchone()[0]
     767                  self.assertEqual(val, txt)
     768  
     769  
     770  class ESC[4;38;5;81mAuthorizerTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     771      @staticmethod
     772      def authorizer_cb(action, arg1, arg2, dbname, source):
     773          if action != sqlite.SQLITE_SELECT:
     774              return sqlite.SQLITE_DENY
     775          if arg2 == 'c2' or arg1 == 't2':
     776              return sqlite.SQLITE_DENY
     777          return sqlite.SQLITE_OK
     778  
     779      def setUp(self):
     780          self.con = sqlite.connect(":memory:")
     781          self.con.executescript("""
     782              create table t1 (c1, c2);
     783              create table t2 (c1, c2);
     784              insert into t1 (c1, c2) values (1, 2);
     785              insert into t2 (c1, c2) values (4, 5);
     786              """)
     787  
     788          # For our security test:
     789          self.con.execute("select c2 from t2")
     790  
     791          self.con.set_authorizer(self.authorizer_cb)
     792  
     793      def tearDown(self):
     794          pass
     795  
     796      def test_table_access(self):
     797          with self.assertRaises(sqlite.DatabaseError) as cm:
     798              self.con.execute("select * from t2")
     799          self.assertIn('prohibited', str(cm.exception))
     800  
     801      def test_column_access(self):
     802          with self.assertRaises(sqlite.DatabaseError) as cm:
     803              self.con.execute("select c2 from t1")
     804          self.assertIn('prohibited', str(cm.exception))
     805  
     806      def test_clear_authorizer(self):
     807          self.con.set_authorizer(None)
     808          self.con.execute("select * from t2")
     809          self.con.execute("select c2 from t1")
     810  
     811  
     812  class ESC[4;38;5;81mAuthorizerRaiseExceptionTests(ESC[4;38;5;149mAuthorizerTests):
     813      @staticmethod
     814      def authorizer_cb(action, arg1, arg2, dbname, source):
     815          if action != sqlite.SQLITE_SELECT:
     816              raise ValueError
     817          if arg2 == 'c2' or arg1 == 't2':
     818              raise ValueError
     819          return sqlite.SQLITE_OK
     820  
     821      @with_tracebacks(ValueError, name="authorizer_cb")
     822      def test_table_access(self):
     823          super().test_table_access()
     824  
     825      @with_tracebacks(ValueError, name="authorizer_cb")
     826      def test_column_access(self):
     827          super().test_table_access()
     828  
     829  class ESC[4;38;5;81mAuthorizerIllegalTypeTests(ESC[4;38;5;149mAuthorizerTests):
     830      @staticmethod
     831      def authorizer_cb(action, arg1, arg2, dbname, source):
     832          if action != sqlite.SQLITE_SELECT:
     833              return 0.0
     834          if arg2 == 'c2' or arg1 == 't2':
     835              return 0.0
     836          return sqlite.SQLITE_OK
     837  
     838  class ESC[4;38;5;81mAuthorizerLargeIntegerTests(ESC[4;38;5;149mAuthorizerTests):
     839      @staticmethod
     840      def authorizer_cb(action, arg1, arg2, dbname, source):
     841          if action != sqlite.SQLITE_SELECT:
     842              return 2**32
     843          if arg2 == 'c2' or arg1 == 't2':
     844              return 2**32
     845          return sqlite.SQLITE_OK
     846  
     847  
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()