(root)/
Python-3.11.7/
Lib/
test/
test_sqlite3/
test_hooks.py
       1  # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
       2  #
       3  # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
       4  #
       5  # This file is part of pysqlite.
       6  #
       7  # This software is provided 'as-is', without any express or implied
       8  # warranty.  In no event will the authors be held liable for any damages
       9  # arising from the use of this software.
      10  #
      11  # Permission is granted to anyone to use this software for any purpose,
      12  # including commercial applications, and to alter it and redistribute it
      13  # freely, subject to the following restrictions:
      14  #
      15  # 1. The origin of this software must not be misrepresented; you must not
      16  #    claim that you wrote the original software. If you use this software
      17  #    in a product, an acknowledgment in the product documentation would be
      18  #    appreciated but is not required.
      19  # 2. Altered source versions must be plainly marked as such, and must not be
      20  #    misrepresented as being the original software.
      21  # 3. This notice may not be removed or altered from any source distribution.
      22  
      23  import contextlib
      24  import sqlite3 as sqlite
      25  import unittest
      26  
      27  from test.support.os_helper import TESTFN, unlink
      28  
      29  from test.test_sqlite3.test_dbapi import memory_database, cx_limit
      30  from test.test_sqlite3.test_userfunctions import with_tracebacks
      31  
      32  
      33  class ESC[4;38;5;81mCollationTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      34      def test_create_collation_not_string(self):
      35          con = sqlite.connect(":memory:")
      36          with self.assertRaises(TypeError):
      37              con.create_collation(None, lambda x, y: (x > y) - (x < y))
      38  
      39      def test_create_collation_not_callable(self):
      40          con = sqlite.connect(":memory:")
      41          with self.assertRaises(TypeError) as cm:
      42              con.create_collation("X", 42)
      43          self.assertEqual(str(cm.exception), 'parameter must be callable')
      44  
      45      def test_create_collation_not_ascii(self):
      46          con = sqlite.connect(":memory:")
      47          con.create_collation("collä", lambda x, y: (x > y) - (x < y))
      48  
      49      def test_create_collation_bad_upper(self):
      50          class ESC[4;38;5;81mBadUpperStr(ESC[4;38;5;149mstr):
      51              def upper(self):
      52                  return None
      53          con = sqlite.connect(":memory:")
      54          mycoll = lambda x, y: -((x > y) - (x < y))
      55          con.create_collation(BadUpperStr("mycoll"), mycoll)
      56          result = con.execute("""
      57              select x from (
      58              select 'a' as x
      59              union
      60              select 'b' as x
      61              ) order by x collate mycoll
      62              """).fetchall()
      63          self.assertEqual(result[0][0], 'b')
      64          self.assertEqual(result[1][0], 'a')
      65  
      66      def test_collation_is_used(self):
      67          def mycoll(x, y):
      68              # reverse order
      69              return -((x > y) - (x < y))
      70  
      71          con = sqlite.connect(":memory:")
      72          con.create_collation("mycoll", mycoll)
      73          sql = """
      74              select x from (
      75              select 'a' as x
      76              union
      77              select 'b' as x
      78              union
      79              select 'c' as x
      80              ) order by x collate mycoll
      81              """
      82          result = con.execute(sql).fetchall()
      83          self.assertEqual(result, [('c',), ('b',), ('a',)],
      84                           msg='the expected order was not returned')
      85  
      86          con.create_collation("mycoll", None)
      87          with self.assertRaises(sqlite.OperationalError) as cm:
      88              result = con.execute(sql).fetchall()
      89          self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
      90  
      91      def test_collation_returns_large_integer(self):
      92          def mycoll(x, y):
      93              # reverse order
      94              return -((x > y) - (x < y)) * 2**32
      95          con = sqlite.connect(":memory:")
      96          con.create_collation("mycoll", mycoll)
      97          sql = """
      98              select x from (
      99              select 'a' as x
     100              union
     101              select 'b' as x
     102              union
     103              select 'c' as x
     104              ) order by x collate mycoll
     105              """
     106          result = con.execute(sql).fetchall()
     107          self.assertEqual(result, [('c',), ('b',), ('a',)],
     108                           msg="the expected order was not returned")
     109  
     110      def test_collation_register_twice(self):
     111          """
     112          Register two different collation functions under the same name.
     113          Verify that the last one is actually used.
     114          """
     115          con = sqlite.connect(":memory:")
     116          con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
     117          con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
     118          result = con.execute("""
     119              select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
     120              """).fetchall()
     121          self.assertEqual(result[0][0], 'b')
     122          self.assertEqual(result[1][0], 'a')
     123  
     124      def test_deregister_collation(self):
     125          """
     126          Register a collation, then deregister it. Make sure an error is raised if we try
     127          to use it.
     128          """
     129          con = sqlite.connect(":memory:")
     130          con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
     131          con.create_collation("mycoll", None)
     132          with self.assertRaises(sqlite.OperationalError) as cm:
     133              con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
     134          self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
     135  
     136  class ESC[4;38;5;81mProgressTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     137      def test_progress_handler_used(self):
     138          """
     139          Test that the progress handler is invoked once it is set.
     140          """
     141          con = sqlite.connect(":memory:")
     142          progress_calls = []
     143          def progress():
     144              progress_calls.append(None)
     145              return 0
     146          con.set_progress_handler(progress, 1)
     147          con.execute("""
     148              create table foo(a, b)
     149              """)
     150          self.assertTrue(progress_calls)
     151  
     152      def test_opcode_count(self):
     153          """
     154          Test that the opcode argument is respected.
     155          """
     156          con = sqlite.connect(":memory:")
     157          progress_calls = []
     158          def progress():
     159              progress_calls.append(None)
     160              return 0
     161          con.set_progress_handler(progress, 1)
     162          curs = con.cursor()
     163          curs.execute("""
     164              create table foo (a, b)
     165              """)
     166          first_count = len(progress_calls)
     167          progress_calls = []
     168          con.set_progress_handler(progress, 2)
     169          curs.execute("""
     170              create table bar (a, b)
     171              """)
     172          second_count = len(progress_calls)
     173          self.assertGreaterEqual(first_count, second_count)
     174  
     175      def test_cancel_operation(self):
     176          """
     177          Test that returning a non-zero value stops the operation in progress.
     178          """
     179          con = sqlite.connect(":memory:")
     180          def progress():
     181              return 1
     182          con.set_progress_handler(progress, 1)
     183          curs = con.cursor()
     184          self.assertRaises(
     185              sqlite.OperationalError,
     186              curs.execute,
     187              "create table bar (a, b)")
     188  
     189      def test_clear_handler(self):
     190          """
     191          Test that setting the progress handler to None clears the previously set handler.
     192          """
     193          con = sqlite.connect(":memory:")
     194          action = 0
     195          def progress():
     196              nonlocal action
     197              action = 1
     198              return 0
     199          con.set_progress_handler(progress, 1)
     200          con.set_progress_handler(None, 1)
     201          con.execute("select 1 union select 2 union select 3").fetchall()
     202          self.assertEqual(action, 0, "progress handler was not cleared")
     203  
     204      @with_tracebacks(ZeroDivisionError, name="bad_progress")
     205      def test_error_in_progress_handler(self):
     206          con = sqlite.connect(":memory:")
     207          def bad_progress():
     208              1 / 0
     209          con.set_progress_handler(bad_progress, 1)
     210          with self.assertRaises(sqlite.OperationalError):
     211              con.execute("""
     212                  create table foo(a, b)
     213                  """)
     214  
     215      @with_tracebacks(ZeroDivisionError, name="bad_progress")
     216      def test_error_in_progress_handler_result(self):
     217          con = sqlite.connect(":memory:")
     218          class ESC[4;38;5;81mBadBool:
     219              def __bool__(self):
     220                  1 / 0
     221          def bad_progress():
     222              return BadBool()
     223          con.set_progress_handler(bad_progress, 1)
     224          with self.assertRaises(sqlite.OperationalError):
     225              con.execute("""
     226                  create table foo(a, b)
     227                  """)
     228  
     229  
     230  class ESC[4;38;5;81mTraceCallbackTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     231      @contextlib.contextmanager
     232      def check_stmt_trace(self, cx, expected):
     233          try:
     234              traced = []
     235              cx.set_trace_callback(lambda stmt: traced.append(stmt))
     236              yield
     237          finally:
     238              self.assertEqual(traced, expected)
     239              cx.set_trace_callback(None)
     240  
     241      def test_trace_callback_used(self):
     242          """
     243          Test that the trace callback is invoked once it is set.
     244          """
     245          con = sqlite.connect(":memory:")
     246          traced_statements = []
     247          def trace(statement):
     248              traced_statements.append(statement)
     249          con.set_trace_callback(trace)
     250          con.execute("create table foo(a, b)")
     251          self.assertTrue(traced_statements)
     252          self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
     253  
     254      def test_clear_trace_callback(self):
     255          """
     256          Test that setting the trace callback to None clears the previously set callback.
     257          """
     258          con = sqlite.connect(":memory:")
     259          traced_statements = []
     260          def trace(statement):
     261              traced_statements.append(statement)
     262          con.set_trace_callback(trace)
     263          con.set_trace_callback(None)
     264          con.execute("create table foo(a, b)")
     265          self.assertFalse(traced_statements, "trace callback was not cleared")
     266  
     267      def test_unicode_content(self):
     268          """
     269          Test that the statement can contain unicode literals.
     270          """
     271          unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
     272          con = sqlite.connect(":memory:")
     273          traced_statements = []
     274          def trace(statement):
     275              traced_statements.append(statement)
     276          con.set_trace_callback(trace)
     277          con.execute("create table foo(x)")
     278          con.execute("insert into foo(x) values ('%s')" % unicode_value)
     279          con.commit()
     280          self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
     281                          "Unicode data %s garbled in trace callback: %s"
     282                          % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
     283  
     284      def test_trace_callback_content(self):
     285          # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
     286          traced_statements = []
     287          def trace(statement):
     288              traced_statements.append(statement)
     289  
     290          queries = ["create table foo(x)",
     291                     "insert into foo(x) values(1)"]
     292          self.addCleanup(unlink, TESTFN)
     293          con1 = sqlite.connect(TESTFN, isolation_level=None)
     294          con2 = sqlite.connect(TESTFN)
     295          try:
     296              con1.set_trace_callback(trace)
     297              cur = con1.cursor()
     298              cur.execute(queries[0])
     299              con2.execute("create table bar(x)")
     300              cur.execute(queries[1])
     301          finally:
     302              con1.close()
     303              con2.close()
     304          self.assertEqual(traced_statements, queries)
     305  
     306      def test_trace_expanded_sql(self):
     307          expected = [
     308              "create table t(t)",
     309              "BEGIN ",
     310              "insert into t values(0)",
     311              "insert into t values(1)",
     312              "insert into t values(2)",
     313              "COMMIT",
     314          ]
     315          with memory_database() as cx, self.check_stmt_trace(cx, expected):
     316              with cx:
     317                  cx.execute("create table t(t)")
     318                  cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
     319  
     320      @with_tracebacks(
     321          sqlite.DataError,
     322          regex="Expanded SQL string exceeds the maximum string length"
     323      )
     324      def test_trace_too_much_expanded_sql(self):
     325          # If the expanded string is too large, we'll fall back to the
     326          # unexpanded SQL statement (for SQLite 3.14.0 and newer).
     327          # The resulting string length is limited by the runtime limit
     328          # SQLITE_LIMIT_LENGTH.
     329          template = "select 1 as a where a="
     330          category = sqlite.SQLITE_LIMIT_LENGTH
     331          with memory_database() as cx, cx_limit(cx, category=category) as lim:
     332              ok_param = "a"
     333              bad_param = "a" * lim
     334  
     335              unexpanded_query = template + "?"
     336              expected = [unexpanded_query]
     337              if sqlite.sqlite_version_info < (3, 14, 0):
     338                  expected = []
     339              with self.check_stmt_trace(cx, expected):
     340                  cx.execute(unexpanded_query, (bad_param,))
     341  
     342              expanded_query = f"{template}'{ok_param}'"
     343              with self.check_stmt_trace(cx, [expanded_query]):
     344                  cx.execute(unexpanded_query, (ok_param,))
     345  
     346      @with_tracebacks(ZeroDivisionError, regex="division by zero")
     347      def test_trace_bad_handler(self):
     348          with memory_database() as cx:
     349              cx.set_trace_callback(lambda stmt: 5/0)
     350              cx.execute("select 1")
     351  
     352  
     353  if __name__ == "__main__":
     354      unittest.main()