1 from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2 NAME_MAPPING, REVERSE_NAME_MAPPING)
3 import builtins
4 import pickle
5 import io
6 import collections
7 import struct
8 import sys
9 import warnings
10 import weakref
11
12 import doctest
13 import unittest
14 from test import support
15 from test.support import import_helper
16
17 from test.pickletester import AbstractHookTests
18 from test.pickletester import AbstractUnpickleTests
19 from test.pickletester import AbstractPickleTests
20 from test.pickletester import AbstractPickleModuleTests
21 from test.pickletester import AbstractPersistentPicklerTests
22 from test.pickletester import AbstractIdentityPersistentPicklerTests
23 from test.pickletester import AbstractPicklerUnpicklerObjectTests
24 from test.pickletester import AbstractDispatchTableTests
25 from test.pickletester import AbstractCustomPicklerClass
26 from test.pickletester import BigmemPickleTests
27
# Detect whether the C accelerator module is available; the C-specific test
# classes further down are only defined when it is.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
33
34
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests run against the pure-Python pickle helpers."""

    # Pure-Python pickler/unpickler classes.
    Pickler, Unpickler = pickle._Pickler, pickle._Unpickler

    # staticmethod keeps these plain functions from being bound to the
    # test instance when accessed through ``self``.
    dump = staticmethod(pickle._dump)
    load = staticmethod(pickle._load)
    dumps = staticmethod(pickle._dumps)
    loads = staticmethod(pickle._loads)
42
43
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests driven through the pure-Python ``_Unpickler`` class."""

    unpickler = pickle._Unpickler

    # Exception types the inherited tests (defined in test.pickletester)
    # presumably accept for pickles with a corrupted stack and for truncated
    # pickles, respectively -- NOTE(review): confirm against pickletester.
    bad_stack_errors = (IndexError,)
    truncated_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ValueError,
        struct.error,
        IndexError,
        ImportError,
    )

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds* to it."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
56
57
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Round-trip tests using the pure-Python pickler and unpickler classes.

    Subclasses swap ``pickler``/``unpickler`` to mix implementations.
    """

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the produced bytes.

        Extra keyword arguments are passed to the pickler constructor.
        """
        sink = io.BytesIO()
        self.pickler(sink, proto, **kwargs).dump(arg)
        return sink.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds* to it."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
74
75
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Tests for the top-level ``pickle.dumps`` / ``pickle.loads`` functions."""

    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ValueError,
        struct.error,
        IndexError,
        ImportError,
    )

    def dumps(self, arg, protocol=None, **kwargs):
        """Serialize *arg* via ``pickle.dumps``."""
        return pickle.dumps(arg, protocol=protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Deserialize *buf* via ``pickle.loads``."""
        return pickle.loads(buf, **kwds)

    # Shadowing the inherited test with a non-callable value drops it from
    # the set of tests unittest collects for this class.
    test_framed_write_sizes_with_delayed_writer = None
91
92
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() helpers that route persistent IDs through
    the host test case.

    The class this is mixed into must define ``pickler`` and ``unpickler``
    (classes) plus ``persistent_id(obj)`` and ``persistent_load(pid)``
    methods. Each call builds a throwaway subclass whose hook delegates
    to those methods.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* and return the bytes.

        Uses a ``self.pickler`` subclass whose ``persistent_id`` delegates
        to ``self.persistent_id``. Extra keyword arguments (for example
        ``fix_imports``) are forwarded to the pickler constructor, matching
        ``PyPicklerTests.dumps``.
        """
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* and return the result.

        Uses a ``self.unpickler`` subclass whose ``persistent_load``
        delegates to ``self.persistent_load``. Keyword arguments are
        forwarded to the unpickler constructor.
        """
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
111
112
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests run through the pure-Python pickler/unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
118
119
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests for the pure-Python classes.

    Also includes CPython-only checks that overriding persistent_id /
    persistent_load, or installing a custom dispatch table, creates no
    reference cycles or leaks.
    """

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        def check(pickler_cls):
            # The override must still round-trip data at every protocol...
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                sink = io.BytesIO()
                pickler_cls(sink, proto).dump('abc')
                self.assertEqual(self.loads(sink.getvalue()), 'abc')
            # ...and an instance must die as soon as its last reference is
            # dropped (no gc pass), i.e. the override forms no cycle.
            instance = pickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_id('def'), 'def')
            instance_ref = weakref.ref(instance)
            del instance
            self.assertIsNone(instance_ref())

        class PlainMethodPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj

        class ClassMethodPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj

        class StaticMethodPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj

        for pickler_cls in (PlainMethodPickler, ClassMethodPickler,
                            StaticMethodPickler):
            check(pickler_cls)

    @support.cpython_only
    def test_custom_pickler_dispatch_table_memleak(self):
        # See https://github.com/python/cpython/issues/89988

        class DispatchTable:
            pass

        class TablePickler(self.pickler):
            def __init__(pickler_self, *args, **kwargs):
                # Assigned before the base __init__ runs.
                pickler_self.dispatch_table = table
                super().__init__(*args, **kwargs)

        table = DispatchTable()
        pickler = TablePickler(io.BytesIO())
        self.assertIs(pickler.dispatch_table, table)
        table_ref = weakref.ref(table)
        self.assertIsNotNone(table_ref())
        # With both the pickler and our own reference gone, the table must
        # be collectable: the pickler must not leak it.
        del pickler, table
        support.gc_collect()
        self.assertIsNone(table_ref())

    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        def check(unpickler_cls):
            # The override must still load data at every protocol...
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                payload = self.dumps('abc', proto)
                loaded = unpickler_cls(io.BytesIO(payload)).load()
                self.assertEqual(loaded, 'abc')
            # ...and an instance must be freed without a gc pass.
            instance = unpickler_cls(io.BytesIO())
            self.assertEqual(instance.persistent_load('def'), 'def')
            instance_ref = weakref.ref(instance)
            del instance
            self.assertIsNone(instance_ref())

        class PlainMethodUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid

        class ClassMethodUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid

        class StaticMethodUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid

        for unpickler_cls in (PlainMethodUnpickler, ClassMethodUnpickler,
                              StaticMethodUnpickler):
            check(unpickler_cls)
208
209
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests,
                                    unittest.TestCase):
    """Pickler/Unpickler object tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
214
215
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler using a plain copy
    of ``pickle.dispatch_table``."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # A fresh shallow copy; mutating it leaves the module-level table
        # untouched.
        table = pickle.dispatch_table.copy()
        return table
222
223
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests for the pure-Python pickler using a ChainMap
    layered over ``pickle.dispatch_table``."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        # Writes land in the empty front mapping; the module-level table
        # is only read through.
        overlay = {}
        return collections.ChainMap(overlay, pickle.dispatch_table)
230
231
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Run AbstractHookTests against a custom pure-Python pickler subclass."""

    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        """Pure-Python pickler combined with the shared custom behavior."""

    pickler_class = CustomPyPicklerClass
237
238
# Test classes for the C accelerator (_pickle); only defined when it imported.
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        """Module-level API tests against the C implementation."""
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        # The C unpickler reports these failures as UnpicklingError only.
        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    # Cross-implementation round trips: C pickler -> Python unpickler ...
    class CDumpPickle_LoadPickle(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    # ... and Python pickler -> C unpickler.
    class DumpPickle_CLoadPickle(PyPicklerTests):
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests,
                                       unittest.TestCase):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    # Name _pickle.Pickler explicitly, like every other class in this
    # block, so these tests always target the C pickler. (Inside this
    # branch it is the same object pickle.Pickler resolves to.)
    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests,
                                   unittest.TestCase):
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """sys.getsizeof() accounting for the C Pickler and Unpickler.

        The calcobjsize() format strings must describe the fixed-size part
        of each C object exactly: each test first asserts them against
        object.__sizeof__().
        """

        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            # Memo table = fixed header (MT_size) + one entry (ME_size) per
            # slot, as used by the size expressions below.
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    # NOTE(review): the "+ 1" after each string length
                    # presumably covers a stored NUL terminator -- confirm
                    # against Modules/_pickle.c.
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1

            def check_unpickler(data, memo_size, marks_size):
                # Load a pickle of *data*, then check the unpickler's size
                # with the given memo-table and mark-stack capacities.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)

            def recurse(deep):
                # Build a list nested *deep* levels, doubling at each level.
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data

            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
360
361
# (python2_module, python3_module) pairs that test_reverse_import_mapping
# exempts from the round-trip requirement. Several Python 2 modules map to
# the same Python 3 module (e.g. StringIO and cStringIO -> io), so the
# reverse table cannot lead back to every one of them.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (module2, name2, module3, name3) entries that test_reverse_name_mapping
# exempts from requiring reverse_mapping() to return the Python 2 name.
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
375
def mapping(module, name):
    """Translate a Python 2 (module, name) pair to its Python 3 location.

    An exact entry in NAME_MAPPING wins. Otherwise only the module is
    renamed via IMPORT_MAPPING. Unknown pairs come back unchanged.
    """
    key = (module, name)
    if key not in NAME_MAPPING:
        return IMPORT_MAPPING.get(module, module), name
    new_module, new_name = NAME_MAPPING[key]
    return new_module, new_name
382
def reverse_mapping(module, name):
    """Translate a Python 3 (module, name) pair back to its Python 2 location.

    An exact entry in REVERSE_NAME_MAPPING wins. Otherwise only the module
    is renamed via REVERSE_IMPORT_MAPPING. Unknown pairs come back unchanged.
    """
    key = (module, name)
    if key not in REVERSE_NAME_MAPPING:
        return REVERSE_IMPORT_MAPPING.get(module, module), name
    old_module, old_name = REVERSE_NAME_MAPPING[key]
    return old_module, old_name
389
def getmodule(module):
    """Return the module named *module*, importing it first if needed.

    During the import, DeprecationWarning is always shown in verbose mode
    and ignored otherwise.

    Raises:
        ImportError: if the module cannot be imported. An AttributeError
            raised by the import is converted to ImportError and chained to
            it (``from exc``), so the original cause stays visible in
            tracebacks.
    """
    try:
        return sys.modules[module]
    except KeyError:
        pass
    # Import outside the KeyError handler so import failures are not
    # reported with a spurious "during handling of KeyError" context.
    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        message = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(message)
        raise ImportError(message, name=module) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
408
def getattribute(module, name):
    """Resolve the dotted attribute path *name* inside module *module*.

    The module is obtained through getmodule(), so it may be imported here.
    AttributeError propagates if any component of *name* is missing.
    """
    target = getmodule(module)
    for component in name.split('.'):
        target = getattr(target, component)
    return target
414
def get_exceptions(mod):
    """Yield ``(name, cls)`` for each attribute of *mod* that is a
    BaseException subclass, in ``dir()`` order."""
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, BaseException):
            yield attr_name, candidate
420
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> Python 3 name tables in
    _compat_pickle."""

    def test_import(self):
        # Try importing every Python 3 module the tables mention; modules
        # that are unavailable here (ImportError) are tolerated.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Public modules must round-trip through IMPORT_MAPPING.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    # Without a module-level reverse entry, at least one
                    # per-name reverse entry must map module3 to module2.
                    has_name_mapping = any(
                        (m3, m2) == (module3, module2)
                        for (m3, _), (m2, _) in REVERSE_NAME_MAPPING.items())
                    if not has_name_mapping:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                try:
                    # Only checks that the attribute resolves (an
                    # AttributeError fails the test); the value is unused.
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Python 3-only exception types are skipped.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning,
                           BaseExceptionGroup,
                           ExceptionGroup):
                    continue
                # OSError/ImportError subclasses collapse onto their base
                # class in Python 2; everything else maps one-to-one.
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
541
542
def load_tests(loader, tests, pattern):
    """unittest load_tests hook: add this module's doctests to *tests*."""
    module_doctests = doctest.DocTestSuite()
    tests.addTest(module_doctests)
    return tests
546
547
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()