1 # pysqlite2/test/factory.py: tests for the various factories in pysqlite
2 #
3 # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
4 #
5 # This file is part of pysqlite.
6 #
7 # This software is provided 'as-is', without any express or implied
8 # warranty. In no event will the authors be held liable for any damages
9 # arising from the use of this software.
10 #
11 # Permission is granted to anyone to use this software for any purpose,
12 # including commercial applications, and to alter it and redistribute it
13 # freely, subject to the following restrictions:
14 #
15 # 1. The origin of this software must not be misrepresented; you must not
16 # claim that you wrote the original software. If you use this software
17 # in a product, an acknowledgment in the product documentation would be
18 # appreciated but is not required.
19 # 2. Altered source versions must be plainly marked as such, and must not be
20 # misrepresented as being the original software.
21 # 3. This notice may not be removed or altered from any source distribution.
22
import sqlite3 as sqlite
import sys
import unittest
from collections.abc import Sequence
26
27
def dict_factory(cursor, row):
    """Row factory mapping each column name to the matching value of *row*.

    Column names are the first element of each ``cursor.description`` entry;
    values are taken from *row* by position.  If two columns share a name,
    the later one wins.
    """
    return {
        column[0]: row[position]
        for position, column in enumerate(cursor.description)
    }
33
class MyCursor(sqlite.Cursor):
    """Cursor subclass that installs ``dict_factory`` as its row factory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set per instance so rows fetched through this cursor are built
        # by dict_factory instead of the connection's row factory.
        self.row_factory = dict_factory
38
class ConnectionFactoryTests(unittest.TestCase):
    """Tests for the ``factory`` argument of ``sqlite.connect()``."""

    def test_connection_factories(self):
        # connect() must return an instance of the factory even when the
        # subclass never calls the base Connection.__init__.
        class DefectFactory(sqlite.Connection):
            def __init__(self, *args, **kwargs):
                return None

        class OkFactory(sqlite.Connection):
            def __init__(self, *args, **kwargs):
                sqlite.Connection.__init__(self, *args, **kwargs)

        for factory in DefectFactory, OkFactory:
            with self.subTest(factory=factory):
                con = sqlite.connect(":memory:", factory=factory)
                if factory is OkFactory:
                    # Only OkFactory opens a database.  DefectFactory skips
                    # Connection.__init__, so it has nothing to release.
                    self.addCleanup(con.close)
                self.assertIsInstance(con, factory)

    def test_connection_factory_relayed_call(self):
        # gh-95132: keyword args must not be passed as positional args
        class Factory(sqlite.Connection):
            def __init__(self, *args, **kwargs):
                kwargs["isolation_level"] = None
                super().__init__(*args, **kwargs)

        con = sqlite.connect(":memory:", factory=Factory)
        self.addCleanup(con.close)
        self.assertIsNone(con.isolation_level)
        self.assertIsInstance(con, Factory)

    def test_connection_factory_as_positional_arg(self):
        class Factory(sqlite.Connection):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

        # Positional order: database, timeout, detect_types,
        # isolation_level, check_same_thread, factory.
        con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory)
        self.addCleanup(con.close)
        self.assertIsNone(con.isolation_level)
        self.assertIsInstance(con, Factory)
72
73
class CursorFactoryTests(unittest.TestCase):
    """Tests for the optional factory argument of ``Connection.cursor()``."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def test_is_instance(self):
        # No factory: a plain sqlite3 Cursor comes back.
        self.assertIsInstance(self.con.cursor(), sqlite.Cursor)
        # A Cursor subclass passed positionally is used as the factory ...
        self.assertIsInstance(self.con.cursor(MyCursor), MyCursor)
        # ... as is any one-argument callable passed by keyword.
        self.assertIsInstance(
            self.con.cursor(factory=lambda con: MyCursor(con)), MyCursor)

    def test_invalid_factory(self):
        # Rejected: something that is not callable.
        with self.assertRaises(TypeError):
            self.con.cursor(None)
        # Rejected: a callable that does not take exactly one argument.
        with self.assertRaises(TypeError):
            self.con.cursor(lambda: None)
        # Rejected: a callable whose result is not a cursor.
        with self.assertRaises(TypeError):
            self.con.cursor(lambda con: None)
96
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
    """Row factory set on a Cursor subclass rather than on the connection."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def test_is_produced_by_factory(self):
        cursor = self.con.cursor(factory=MyCursor)
        # execute() returns the cursor itself, so fetchone() can be chained.
        first = cursor.execute("select 4+5 as foo").fetchone()
        # MyCursor installs dict_factory, hence a dict row.
        self.assertIsInstance(first, dict)
        cursor.close()
110
class RowFactoryTests(unittest.TestCase):
    """Tests for ``sqlite.Row`` and for custom connection row factories."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.row_factory = sqlite.Row

    def tearDown(self):
        self.con.close()

    def test_custom_factory(self):
        # Any callable taking (cursor, row) can replace sqlite.Row.
        self.con.row_factory = lambda cur, row: list(row)
        row = self.con.execute("select 1, 2").fetchone()
        self.assertIsInstance(row, list)
        self.assertEqual(row, [1, 2])

    def test_sqlite_row_index(self):
        row = self.con.execute("select 1 as a_1, 2 as b").fetchone()
        self.assertIsInstance(row, sqlite.Row)

        self.assertEqual(row["a_1"], 1, "by name: wrong result for column 'a_1'")
        self.assertEqual(row["b"], 2, "by name: wrong result for column 'b'")

        # Lookup by name ignores ASCII case.
        self.assertEqual(row["A_1"], 1, "by name: wrong result for column 'A_1'")
        self.assertEqual(row["B"], 2, "by name: wrong result for column 'B'")

        self.assertEqual(row[0], 1, "by index: wrong result for column 0")
        self.assertEqual(row[1], 2, "by index: wrong result for column 1")
        self.assertEqual(row[-1], 2, "by index: wrong result for column -1")
        self.assertEqual(row[-2], 1, "by index: wrong result for column -2")

        with self.assertRaises(IndexError):
            row['c']
        # '\x11' and '\x7f' differ from '1' and '_' only in bit 0x20, so a
        # comparison that merely masks the ASCII case bit would wrongly
        # match these keys against 'a_1'.
        with self.assertRaises(IndexError):
            row['a_\x11']
        with self.assertRaises(IndexError):
            row['a\x7f1']
        # Out-of-range positions raise IndexError, even for a huge int.
        with self.assertRaises(IndexError):
            row[2]
        with self.assertRaises(IndexError):
            row[-3]
        with self.assertRaises(IndexError):
            row[2**1000]
        with self.assertRaises(IndexError):
            row[complex()]  # index must be int or string

    def test_sqlite_row_index_unicode(self):
        row = self.con.execute("select 1 as \xff").fetchone()
        self.assertEqual(row["\xff"], 1)
        # Case-insensitive matching is ASCII-only: neither the Unicode
        # uppercase of '\xff' ('\u0178') nor its case-bit twin '\xdf'
        # (0xff ^ 0x20) may match.
        with self.assertRaises(IndexError):
            row['\u0178']
        with self.assertRaises(IndexError):
            row['\xdf']

    def test_sqlite_row_slice(self):
        # A sqlite.Row can be sliced like a list.
        row = self.con.execute("select 1, 2, 3, 4").fetchone()
        self.assertEqual(row[0:0], ())
        self.assertEqual(row[0:1], (1,))
        self.assertEqual(row[1:3], (2, 3))
        self.assertEqual(row[3:1], ())
        # Explicit bounds are optional.
        self.assertEqual(row[1:], (2, 3, 4))
        self.assertEqual(row[:3], (1, 2, 3))
        # Slices can use negative indices.
        self.assertEqual(row[-2:-1], (3,))
        self.assertEqual(row[-2:], (3, 4))
        # Slicing supports steps.
        self.assertEqual(row[0:4:2], (1, 3))
        self.assertEqual(row[3:0:-2], (4, 2))

    def test_sqlite_row_iter(self):
        # Checks if the row object is iterable.
        row = self.con.execute("select 1 as a, 2 as b").fetchone()

        # Is iterable in correct order and produces valid results:
        self.assertEqual(list(row), [1, 2])

        # Is iterable the second time:
        self.assertEqual(list(row), [1, 2])

    def test_sqlite_row_as_tuple(self):
        # Checks if the row object can be converted to a tuple.
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        t = tuple(row)
        self.assertEqual(t, (row['a'], row['b']))

    def test_sqlite_row_as_dict(self):
        # Checks if the row object can be correctly converted to a dictionary.
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        d = dict(row)
        # Exact comparison also rules out missing or extra keys.
        self.assertEqual(d, {"a": 1, "b": 2})
        self.assertEqual(d["a"], row["a"])
        self.assertEqual(d["b"], row["b"])

    def test_sqlite_row_hash_cmp(self):
        # Checks if the row object compares and hashes correctly.
        # Equality needs the same column names AND values in the same order:
        # row_4 swaps the names, and row_5 has the same name->value pairs as
        # row_1 but in a different column order; both must compare unequal.
        row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
        row_4 = self.con.execute("select 1 as b, 2 as a").fetchone()
        row_5 = self.con.execute("select 2 as b, 1 as a").fetchone()

        self.assertTrue(row_1 == row_1)
        self.assertTrue(row_1 == row_2)
        self.assertFalse(row_1 == row_3)
        self.assertFalse(row_1 == row_4)
        self.assertFalse(row_1 == row_5)
        self.assertFalse(row_1 == object())

        self.assertFalse(row_1 != row_1)
        self.assertFalse(row_1 != row_2)
        self.assertTrue(row_1 != row_3)
        self.assertTrue(row_1 != row_4)
        self.assertTrue(row_1 != row_5)
        self.assertTrue(row_1 != object())

        # Rows support equality only; ordering comparisons are rejected.
        with self.assertRaises(TypeError):
            row_1 > row_2
        with self.assertRaises(TypeError):
            row_1 < row_2
        with self.assertRaises(TypeError):
            row_1 >= row_2
        with self.assertRaises(TypeError):
            row_1 <= row_2

        self.assertEqual(hash(row_1), hash(row_2))

    def test_sqlite_row_as_sequence(self):
        # Checks if the row object can act like a sequence.
        row = self.con.execute("select 1 as a, 2 as b").fetchone()

        as_tuple = tuple(row)
        self.assertEqual(len(row), len(as_tuple))
        self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
        self.assertIsInstance(row, Sequence)

    def test_sqlite_row_keys(self):
        # Checks if the row object can return a list of columns as strings.
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        self.assertEqual(row.keys(), ['a', 'b'])

    def test_fake_cursor_class(self):
        # Issue #24257: Incorrect use of PyObject_IsInstance() caused
        # segmentation fault.
        # Issue #27861: Also applies for cursor factory.
        class FakeCursor(str):
            __class__ = sqlite.Cursor
        self.assertRaises(TypeError, self.con.cursor, FakeCursor)
        self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
258
259 class ESC[4;38;5;81mTextFactoryTests(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
260 def setUp(self):
261 self.con = sqlite.connect(":memory:")
262
263 def test_unicode(self):
264 austria = "Österreich"
265 row = self.con.execute("select ?", (austria,)).fetchone()
266 self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
267
268 def test_string(self):
269 self.con.text_factory = bytes
270 austria = "Österreich"
271 row = self.con.execute("select ?", (austria,)).fetchone()
272 self.assertEqual(type(row[0]), bytes, "type of row[0] must be bytes")
273 self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8")
274
275 def test_custom(self):
276 self.con.text_factory = lambda x: str(x, "utf-8", "ignore")
277 austria = "Österreich"
278 row = self.con.execute("select ?", (austria,)).fetchone()
279 self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
280 self.assertTrue(row[0].endswith("reich"), "column must contain original data")
281
282 def test_optimized_unicode(self):
283 # OptimizedUnicode is deprecated as of Python 3.10
284 with self.assertWarns(DeprecationWarning) as cm:
285 self.con.text_factory = sqlite.OptimizedUnicode
286 self.assertIn("factory.py", cm.filename)
287 austria = "Österreich"
288 germany = "Deutchland"
289 a_row = self.con.execute("select ?", (austria,)).fetchone()
290 d_row = self.con.execute("select ?", (germany,)).fetchone()
291 self.assertEqual(type(a_row[0]), str, "type of non-ASCII row must be str")
292 self.assertEqual(type(d_row[0]), str, "type of ASCII-only row must be str")
293
294 def tearDown(self):
295 self.con.close()
296
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
    """A TEXT value containing a NUL byte must survive every text_factory."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.execute("create table test (value text)")
        self.con.execute("insert into test (value) values (?)", ("a\x00b",))

    def tearDown(self):
        self.con.close()

    def _first_value(self):
        """Return column 0 of the first row, converted by the current text_factory."""
        return self.con.execute("select value from test").fetchone()[0]

    def test_string(self):
        # text_factory defaults to str
        value = self._first_value()
        self.assertIs(type(value), str)
        self.assertEqual(value, "a\x00b")

    def test_bytes(self):
        self.con.text_factory = bytes
        value = self._first_value()
        self.assertIs(type(value), bytes)
        self.assertEqual(value, b"a\x00b")

    def test_bytearray(self):
        self.con.text_factory = bytearray
        value = self._first_value()
        self.assertIs(type(value), bytearray)
        self.assertEqual(value, b"a\x00b")

    def test_custom(self):
        # A custom factory should receive a bytes argument
        self.con.text_factory = lambda x: x
        value = self._first_value()
        self.assertIs(type(value), bytes)
        self.assertEqual(value, b"a\x00b")
330
331
# Running this file directly hands control to unittest's command-line
# runner for the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()