1 """Unit tests for contextlib.py, and other context managers."""
2
3 import io
4 import os
5 import sys
6 import tempfile
7 import threading
8 import traceback
9 import unittest
10 from contextlib import * # Tests __all__
11 from test import support
12 from test.support import os_helper
13 from test.support.testcase import ExceptionIsLikeMixin
14 import weakref
15
16
class TestAbstractContextManager(unittest.TestCase):
    """Tests for the contextlib.AbstractContextManager ABC."""

    def test_enter(self):
        # A subclass that only implements __exit__ inherits the default
        # __enter__, which hands back the manager itself.
        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *exc_info):
                super().__exit__(*exc_info)

        cm = DefaultEnter()
        entered = cm.__enter__()
        self.assertIs(entered, cm)

    def test_exit_is_abstract(self):
        # Omitting __exit__ leaves the subclass abstract, so instantiating
        # it must fail.
        class MissingExit(AbstractContextManager):
            pass

        self.assertRaises(TypeError, MissingExit)

    def test_structural_subclassing(self):
        class ManagerFromScratch:
            def __enter__(self):
                return self
            def __exit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *exc_info):
                super().__exit__(*exc_info)

        # Setting either protocol method to None opts the class out.
        class NoEnter(ManagerFromScratch):
            __enter__ = None

        class NoExit(ManagerFromScratch):
            __exit__ = None

        for conforming in (ManagerFromScratch, DefaultEnter):
            self.assertTrue(issubclass(conforming, AbstractContextManager))
        for opted_out in (NoEnter, NoExit):
            self.assertFalse(issubclass(opted_out, AbstractContextManager))
58
59
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # The generator's finally clause runs even when the body raises.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_traceback(self):
        # Only the test's own frame should appear in the traceback of an
        # exception raised in the with-body.
        # NOTE: the assertions compare against the literal source text of the
        # raising lines, so those lines must not carry trailing comments.
        @contextmanager
        def f():
            yield

        try:
            with f():
                1/0
        except ZeroDivisionError as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('ZeroDivisionError was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, '1/0')

        # Repeat with RuntimeError (which goes through a different code path)
        class RuntimeErrorSubclass(RuntimeError):
            pass

        try:
            with f():
                raise RuntimeErrorSubclass(42)
        except RuntimeErrorSubclass as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('RuntimeErrorSubclass was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopIterationSubclass('spam'),
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    with f():
                        raise stop_exc
                except type(stop_exc) as e:
                    self.assertIs(e, stop_exc)
                    frames = traceback.extract_tb(e.__traceback__)
                else:
                    self.fail(f'{stop_exc} was suppressed')

                self.assertEqual(len(frames), 1)
                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
                self.assertEqual(frames[0].line, 'raise stop_exc')

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into it
        # violates the protocol; __exit__ reports that as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator catches the body's exception, which suppresses it.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_except_stopiter(self):
        @contextmanager
        def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
            with self.subTest(type=type(stop_exc)):
                try:
                    with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    def test_contextmanager_except_pep479(self):
        code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
    yield
"""
        # Use a dedicated namespace rather than shadowing the locals() builtin.
        namespace = {}
        exec(code, namespace, namespace)
        woohoo = namespace['woohoo']

        stop_exc = StopIteration('spam')
        try:
            with woohoo():
                raise stop_exc
        except Exception as ex:
            self.assertIs(ex, stop_exc)
        else:
            self.fail('StopIteration was suppressed')

    def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
        @contextmanager
        def test_issue29692():
            try:
                yield
            except Exception as exc:
                raise RuntimeError('issue29692:Chained') from exc
        try:
            with test_issue29692():
                raise ZeroDivisionError
        except Exception as ex:
            self.assertIs(type(ex), RuntimeError)
            self.assertEqual(ex.args[0], 'issue29692:Chained')
            self.assertIsInstance(ex.__cause__, ZeroDivisionError)
        else:
            self.fail('Expected RuntimeError, but no exception was raised')

        try:
            with test_issue29692():
                raise StopIteration('issue29692:Unchained')
        except Exception as ex:
            self.assertIs(type(ex), StopIteration)
            self.assertEqual(ex.args[0], 'issue29692:Unchained')
            self.assertIsNone(ex.__cause__)
        else:
            self.fail('Expected StopIteration, but no exception was raised')

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager function that carries extra attributes
        (foo='bar') and a docstring, to check metadata propagation."""
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    def test_nokeepref(self):
        # The context manager must not keep the call arguments alive once
        # the generator has dropped its own references to them.
        class A:
            pass

        @contextmanager
        def woohoo(a, b):
            a = weakref.ref(a)
            b = weakref.ref(b)
            # Allow test to work with a non-refcounted GC
            support.gc_collect()
            self.assertIsNone(a())
            self.assertIsNone(b())
            yield

        with woohoo(A(), b=A()):
            pass

    def test_param_errors(self):
        @contextmanager
        def woohoo(a, *, b):
            yield

        with self.assertRaises(TypeError):
            woohoo()
        with self.assertRaises(TypeError):
            woohoo(3, 5)
        with self.assertRaises(TypeError):
            woohoo(b=3)

    def test_recursive(self):
        # Used as a decorator, the same @contextmanager object must be
        # re-entered freshly on every (recursive) call.
        depth = 0
        @contextmanager
        def woohoo():
            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        def recursive():
            if depth < 10:
                recursive()

        recursive()
        self.assertEqual(depth, 0)
326
327
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = closing.__doc__
        obj = closing(None)
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_closing(self):
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            # The `as` target is the wrapped object itself, and close() must
            # not run until the block exits.
            self.assertIs(x, y)
            self.assertEqual(state, [])
        self.assertEqual(state, [1])

    def test_closing_error(self):
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertIs(x, y)
                1 / 0
        # close() still runs when the body raises.
        self.assertEqual(state, [1])
360
361
class NullcontextTestCase(unittest.TestCase):
    """Tests for contextlib.nullcontext."""

    def test_nullcontext(self):
        class C:
            pass
        c = C()
        with nullcontext(c) as c_in:
            self.assertIs(c_in, c)

    def test_nullcontext_default(self):
        # Without an argument, the `as` target is None.
        with nullcontext() as c_in:
            self.assertIsNone(c_in)

    def test_nullcontext_does_not_suppress(self):
        # nullcontext is a no-op: exceptions from the body propagate.
        with self.assertRaises(ZeroDivisionError):
            with nullcontext():
                1 / 0
369
370
371 class ESC[4;38;5;81mFileContextTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
372
373 def testWithOpen(self):
374 tfn = tempfile.mktemp()
375 try:
376 f = None
377 with open(tfn, "w", encoding="utf-8") as f:
378 self.assertFalse(f.closed)
379 f.write("Booh\n")
380 self.assertTrue(f.closed)
381 f = None
382 with self.assertRaises(ZeroDivisionError):
383 with open(tfn, "r", encoding="utf-8") as f:
384 self.assertFalse(f.closed)
385 self.assertEqual(f.read(), "Booh\n")
386 1 / 0
387 self.assertTrue(f.closed)
388 finally:
389 os_helper.unlink(tfn)
390
class LockContextTestCase(unittest.TestCase):
    """Check that threading synchronization primitives work as context managers."""

    def boilerPlate(self, lock, locked):
        """Assert *lock* is held inside a with block and released after it.

        *locked* is a zero-argument callable that reports whether the lock is
        currently held. Release is also checked when the body raises.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_probe(sem):
        """Return a callable reporting whether *sem* has no free slot.

        The probe tries a non-blocking acquire. If it succeeds, the slot is
        released straight away and the probe reports "not held".
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is a private helper; RLock exposes no public "locked()".
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))
437
438
class mycontext(ContextDecorator):
    """Example decoration-compatible context manager for testing"""
    # Class-level defaults; instances shadow them when entered/exited or
    # when a test sets `catch` explicitly.
    catch = False
    exc = None
    started = False

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        # Remember the (type, value, traceback) triple, and swallow the
        # exception only when `catch` is true.
        self.exc = exc_info
        return self.catch
452
453
class TestContextDecorator(unittest.TestCase):
    """Tests for ContextDecorator, mostly via the module-level mycontext."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = mycontext.__doc__
        obj = mycontext()
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_contextdecorator(self):
        context = mycontext()
        with context as result:
            self.assertIs(result, context)
            self.assertTrue(context.started)

        self.assertEqual(context.exc, (None, None, None))

    def test_contextdecorator_with_exception(self):
        context = mycontext()

        with self.assertRaisesRegex(NameError, 'foo'):
            with context:
                raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

        # With catch=True, __exit__ returns True and the error is swallowed.
        context = mycontext()
        context.catch = True
        with context:
            raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorator(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
            return 'result'
        # The wrapper passes the decorated function's return value through.
        self.assertEqual(test(), 'result')
        self.assertEqual(context.exc, (None, None, None))

    def test_decorator_with_exception(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
            raise NameError('foo')

        with self.assertRaisesRegex(NameError, 'foo'):
            test()
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorating_method(self):
        context = mycontext()

        class Test(object):

            @context
            def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        # The default for c must survive keyword-only calls as well.
        self.assertIsNone(test.c)

    def test_typo_enter(self):
        class mycontext(ContextDecorator):
            def __unter__(self):
                pass
            def __exit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager'):
            with mycontext():
                pass

    def test_typo_exit(self):
        class mycontext(ContextDecorator):
            def __enter__(self):
                pass
            def __uxit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
            with mycontext():
                pass

    def test_contextdecorator_as_mixin(self):
        class somecontext(object):
            started = False
            exc = None

            def __enter__(self):
                self.started = True
                return self

            def __exit__(self, *exc):
                self.exc = exc

        class mycontext(somecontext, ContextDecorator):
            pass

        context = mycontext()
        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))

    def test_contextmanager_as_decorator(self):
        @contextmanager
        def woohoo(y):
            state.append(y)
            yield
            state.append(999)

        state = []
        @woohoo(1)
        def test(x):
            self.assertEqual(state, [1])
            state.append(x)
        test('something')
        self.assertEqual(state, [1, 'something', 999])

        # Issue #11647: Ensure the decorated function is 'reusable'
        state = []
        test('something else')
        self.assertEqual(state, [1, 'something else', 999])
612
613
614 class ESC[4;38;5;81mTestBaseExitStack:
615 exit_stack = None
616
617 @support.requires_docstrings
618 def test_instance_docs(self):
619 # Issue 19330: ensure context manager instances have good docstrings
620 cm_docstring = self.exit_stack.__doc__
621 obj = self.exit_stack()
622 self.assertEqual(obj.__doc__, cm_docstring)
623
624 def test_no_resources(self):
625 with self.exit_stack():
626 pass
627
628 def test_callback(self):
629 expected = [
630 ((), {}),
631 ((1,), {}),
632 ((1,2), {}),
633 ((), dict(example=1)),
634 ((1,), dict(example=1)),
635 ((1,2), dict(example=1)),
636 ((1,2), dict(self=3, callback=4)),
637 ]
638 result = []
639 def _exit(*args, **kwds):
640 """Test metadata propagation"""
641 result.append((args, kwds))
642 with self.exit_stack() as stack:
643 for args, kwds in reversed(expected):
644 if args and kwds:
645 f = stack.callback(_exit, *args, **kwds)
646 elif args:
647 f = stack.callback(_exit, *args)
648 elif kwds:
649 f = stack.callback(_exit, **kwds)
650 else:
651 f = stack.callback(_exit)
652 self.assertIs(f, _exit)
653 for wrapper in stack._exit_callbacks:
654 self.assertIs(wrapper[1].__wrapped__, _exit)
655 self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
656 self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
657 self.assertEqual(result, expected)
658
659 result = []
660 with self.exit_stack() as stack:
661 with self.assertRaises(TypeError):
662 stack.callback(arg=1)
663 with self.assertRaises(TypeError):
664 self.exit_stack.callback(arg=2)
665 with self.assertRaises(TypeError):
666 stack.callback(callback=_exit, arg=3)
667 self.assertEqual(result, [])
668
669 def test_push(self):
670 exc_raised = ZeroDivisionError
671 def _expect_exc(exc_type, exc, exc_tb):
672 self.assertIs(exc_type, exc_raised)
673 def _suppress_exc(*exc_details):
674 return True
675 def _expect_ok(exc_type, exc, exc_tb):
676 self.assertIsNone(exc_type)
677 self.assertIsNone(exc)
678 self.assertIsNone(exc_tb)
679 class ESC[4;38;5;81mExitCM(ESC[4;38;5;149mobject):
680 def __init__(self, check_exc):
681 self.check_exc = check_exc
682 def __enter__(self):
683 self.fail("Should not be called!")
684 def __exit__(self, *exc_details):
685 self.check_exc(*exc_details)
686 with self.exit_stack() as stack:
687 stack.push(_expect_ok)
688 self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
689 cm = ExitCM(_expect_ok)
690 stack.push(cm)
691 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
692 stack.push(_suppress_exc)
693 self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
694 cm = ExitCM(_expect_exc)
695 stack.push(cm)
696 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
697 stack.push(_expect_exc)
698 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
699 stack.push(_expect_exc)
700 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
701 1/0
702
703 def test_enter_context(self):
704 class ESC[4;38;5;81mTestCM(ESC[4;38;5;149mobject):
705 def __enter__(self):
706 result.append(1)
707 def __exit__(self, *exc_details):
708 result.append(3)
709
710 result = []
711 cm = TestCM()
712 with self.exit_stack() as stack:
713 @stack.callback # Registered first => cleaned up last
714 def _exit():
715 result.append(4)
716 self.assertIsNotNone(_exit)
717 stack.enter_context(cm)
718 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
719 result.append(2)
720 self.assertEqual(result, [1, 2, 3, 4])
721
722 def test_enter_context_errors(self):
723 class ESC[4;38;5;81mLacksEnterAndExit:
724 pass
725 class ESC[4;38;5;81mLacksEnter:
726 def __exit__(self, *exc_info):
727 pass
728 class ESC[4;38;5;81mLacksExit:
729 def __enter__(self):
730 pass
731
732 with self.exit_stack() as stack:
733 with self.assertRaisesRegex(TypeError, 'the context manager'):
734 stack.enter_context(LacksEnterAndExit())
735 with self.assertRaisesRegex(TypeError, 'the context manager'):
736 stack.enter_context(LacksEnter())
737 with self.assertRaisesRegex(TypeError, 'the context manager'):
738 stack.enter_context(LacksExit())
739 self.assertFalse(stack._exit_callbacks)
740
741 def test_close(self):
742 result = []
743 with self.exit_stack() as stack:
744 @stack.callback
745 def _exit():
746 result.append(1)
747 self.assertIsNotNone(_exit)
748 stack.close()
749 result.append(2)
750 self.assertEqual(result, [1, 2])
751
752 def test_pop_all(self):
753 result = []
754 with self.exit_stack() as stack:
755 @stack.callback
756 def _exit():
757 result.append(3)
758 self.assertIsNotNone(_exit)
759 new_stack = stack.pop_all()
760 result.append(1)
761 result.append(2)
762 new_stack.close()
763 self.assertEqual(result, [1, 2, 3])
764
765 def test_exit_raise(self):
766 with self.assertRaises(ZeroDivisionError):
767 with self.exit_stack() as stack:
768 stack.push(lambda *exc: False)
769 1/0
770
771 def test_exit_suppress(self):
772 with self.exit_stack() as stack:
773 stack.push(lambda *exc: True)
774 1/0
775
776 def test_exit_exception_traceback(self):
777 # This test captures the current behavior of ExitStack so that we know
778 # if we ever unintendedly change it. It is not a statement of what the
779 # desired behavior is (for instance, we may want to remove some of the
780 # internal contextlib frames).
781
782 def raise_exc(exc):
783 raise exc
784
785 try:
786 with self.exit_stack() as stack:
787 stack.callback(raise_exc, ValueError)
788 1/0
789 except ValueError as e:
790 exc = e
791
792 self.assertIsInstance(exc, ValueError)
793 ve_frames = traceback.extract_tb(exc.__traceback__)
794 expected = \
795 [('test_exit_exception_traceback', 'with self.exit_stack() as stack:')] + \
796 self.callback_error_internal_frames + \
797 [('_exit_wrapper', 'callback(*args, **kwds)'),
798 ('raise_exc', 'raise exc')]
799
800 self.assertEqual(
801 [(f.name, f.line) for f in ve_frames], expected)
802
803 self.assertIsInstance(exc.__context__, ZeroDivisionError)
804 zde_frames = traceback.extract_tb(exc.__context__.__traceback__)
805 self.assertEqual([(f.name, f.line) for f in zde_frames],
806 [('test_exit_exception_traceback', '1/0')])
807
808 def test_exit_exception_chaining_reference(self):
809 # Sanity check to make sure that ExitStack chaining matches
810 # actual nested with statements
811 class ESC[4;38;5;81mRaiseExc:
812 def __init__(self, exc):
813 self.exc = exc
814 def __enter__(self):
815 return self
816 def __exit__(self, *exc_details):
817 raise self.exc
818
819 class ESC[4;38;5;81mRaiseExcWithContext:
820 def __init__(self, outer, inner):
821 self.outer = outer
822 self.inner = inner
823 def __enter__(self):
824 return self
825 def __exit__(self, *exc_details):
826 try:
827 raise self.inner
828 except:
829 raise self.outer
830
831 class ESC[4;38;5;81mSuppressExc:
832 def __enter__(self):
833 return self
834 def __exit__(self, *exc_details):
835 type(self).saved_details = exc_details
836 return True
837
838 try:
839 with RaiseExc(IndexError):
840 with RaiseExcWithContext(KeyError, AttributeError):
841 with SuppressExc():
842 with RaiseExc(ValueError):
843 1 / 0
844 except IndexError as exc:
845 self.assertIsInstance(exc.__context__, KeyError)
846 self.assertIsInstance(exc.__context__.__context__, AttributeError)
847 # Inner exceptions were suppressed
848 self.assertIsNone(exc.__context__.__context__.__context__)
849 else:
850 self.fail("Expected IndexError, but no exception was raised")
851 # Check the inner exceptions
852 inner_exc = SuppressExc.saved_details[1]
853 self.assertIsInstance(inner_exc, ValueError)
854 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
855
856 def test_exit_exception_chaining(self):
857 # Ensure exception chaining matches the reference behaviour
858 def raise_exc(exc):
859 raise exc
860
861 saved_details = None
862 def suppress_exc(*exc_details):
863 nonlocal saved_details
864 saved_details = exc_details
865 return True
866
867 try:
868 with self.exit_stack() as stack:
869 stack.callback(raise_exc, IndexError)
870 stack.callback(raise_exc, KeyError)
871 stack.callback(raise_exc, AttributeError)
872 stack.push(suppress_exc)
873 stack.callback(raise_exc, ValueError)
874 1 / 0
875 except IndexError as exc:
876 self.assertIsInstance(exc.__context__, KeyError)
877 self.assertIsInstance(exc.__context__.__context__, AttributeError)
878 # Inner exceptions were suppressed
879 self.assertIsNone(exc.__context__.__context__.__context__)
880 else:
881 self.fail("Expected IndexError, but no exception was raised")
882 # Check the inner exceptions
883 inner_exc = saved_details[1]
884 self.assertIsInstance(inner_exc, ValueError)
885 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
886
887 def test_exit_exception_explicit_none_context(self):
888 # Ensure ExitStack chaining matches actual nested `with` statements
889 # regarding explicit __context__ = None.
890
891 class ESC[4;38;5;81mMyException(ESC[4;38;5;149mException):
892 pass
893
894 @contextmanager
895 def my_cm():
896 try:
897 yield
898 except BaseException:
899 exc = MyException()
900 try:
901 raise exc
902 finally:
903 exc.__context__ = None
904
905 @contextmanager
906 def my_cm_with_exit_stack():
907 with self.exit_stack() as stack:
908 stack.enter_context(my_cm())
909 yield stack
910
911 for cm in (my_cm, my_cm_with_exit_stack):
912 with self.subTest():
913 try:
914 with cm():
915 raise IndexError()
916 except MyException as exc:
917 self.assertIsNone(exc.__context__)
918 else:
919 self.fail("Expected IndexError, but no exception was raised")
920
921 def test_exit_exception_non_suppressing(self):
922 # http://bugs.python.org/issue19092
923 def raise_exc(exc):
924 raise exc
925
926 def suppress_exc(*exc_details):
927 return True
928
929 try:
930 with self.exit_stack() as stack:
931 stack.callback(lambda: None)
932 stack.callback(raise_exc, IndexError)
933 except Exception as exc:
934 self.assertIsInstance(exc, IndexError)
935 else:
936 self.fail("Expected IndexError, but no exception was raised")
937
938 try:
939 with self.exit_stack() as stack:
940 stack.callback(raise_exc, KeyError)
941 stack.push(suppress_exc)
942 stack.callback(raise_exc, IndexError)
943 except Exception as exc:
944 self.assertIsInstance(exc, KeyError)
945 else:
946 self.fail("Expected KeyError, but no exception was raised")
947
948 def test_exit_exception_with_correct_context(self):
949 # http://bugs.python.org/issue20317
950 @contextmanager
951 def gets_the_context_right(exc):
952 try:
953 yield
954 finally:
955 raise exc
956
957 exc1 = Exception(1)
958 exc2 = Exception(2)
959 exc3 = Exception(3)
960 exc4 = Exception(4)
961
962 # The contextmanager already fixes the context, so prior to the
963 # fix, ExitStack would try to fix it *again* and get into an
964 # infinite self-referential loop
965 try:
966 with self.exit_stack() as stack:
967 stack.enter_context(gets_the_context_right(exc4))
968 stack.enter_context(gets_the_context_right(exc3))
969 stack.enter_context(gets_the_context_right(exc2))
970 raise exc1
971 except Exception as exc:
972 self.assertIs(exc, exc4)
973 self.assertIs(exc.__context__, exc3)
974 self.assertIs(exc.__context__.__context__, exc2)
975 self.assertIs(exc.__context__.__context__.__context__, exc1)
976 self.assertIsNone(
977 exc.__context__.__context__.__context__.__context__)
978
979 def test_exit_exception_with_existing_context(self):
980 # Addresses a lack of test coverage discovered after checking in a
981 # fix for issue 20317 that still contained debugging code.
982 def raise_nested(inner_exc, outer_exc):
983 try:
984 raise inner_exc
985 finally:
986 raise outer_exc
987 exc1 = Exception(1)
988 exc2 = Exception(2)
989 exc3 = Exception(3)
990 exc4 = Exception(4)
991 exc5 = Exception(5)
992 try:
993 with self.exit_stack() as stack:
994 stack.callback(raise_nested, exc4, exc5)
995 stack.callback(raise_nested, exc2, exc3)
996 raise exc1
997 except Exception as exc:
998 self.assertIs(exc, exc5)
999 self.assertIs(exc.__context__, exc4)
1000 self.assertIs(exc.__context__.__context__, exc3)
1001 self.assertIs(exc.__context__.__context__.__context__, exc2)
1002 self.assertIs(
1003 exc.__context__.__context__.__context__.__context__, exc1)
1004 self.assertIsNone(
1005 exc.__context__.__context__.__context__.__context__.__context__)
1006
1007 def test_body_exception_suppress(self):
1008 def suppress_exc(*exc_details):
1009 return True
1010 try:
1011 with self.exit_stack() as stack:
1012 stack.push(suppress_exc)
1013 1/0
1014 except IndexError as exc:
1015 self.fail("Expected no exception, got IndexError")
1016
1017 def test_exit_exception_chaining_suppress(self):
1018 with self.exit_stack() as stack:
1019 stack.push(lambda *exc: True)
1020 stack.push(lambda *exc: 1/0)
1021 stack.push(lambda *exc: {}[1])
1022
1023 def test_excessive_nesting(self):
1024 # The original implementation would die with RecursionError here
1025 with self.exit_stack() as stack:
1026 for i in range(10000):
1027 stack.callback(int)
1028
1029 def test_instance_bypass(self):
1030 class ESC[4;38;5;81mExample(ESC[4;38;5;149mobject): pass
1031 cm = Example()
1032 cm.__enter__ = object()
1033 cm.__exit__ = object()
1034 stack = self.exit_stack()
1035 with self.assertRaisesRegex(TypeError, 'the context manager'):
1036 stack.enter_context(cm)
1037 stack.push(cm)
1038 self.assertIs(stack._exit_callbacks[-1][1], cm)
1039
1040 def test_dont_reraise_RuntimeError(self):
1041 # https://bugs.python.org/issue27122
1042 class ESC[4;38;5;81mUniqueException(ESC[4;38;5;149mException): pass
1043 class ESC[4;38;5;81mUniqueRuntimeError(ESC[4;38;5;149mRuntimeError): pass
1044
1045 @contextmanager
1046 def second():
1047 try:
1048 yield 1
1049 except Exception as exc:
1050 raise UniqueException("new exception") from exc
1051
1052 @contextmanager
1053 def first():
1054 try:
1055 yield 1
1056 except Exception as exc:
1057 raise exc
1058
1059 # The UniqueRuntimeError should be caught by second()'s exception
1060 # handler which chain raised a new UniqueException.
1061 with self.assertRaises(UniqueException) as err_ctx:
1062 with self.exit_stack() as es_ctx:
1063 es_ctx.enter_context(second())
1064 es_ctx.enter_context(first())
1065 raise UniqueRuntimeError("please no infinite loop.")
1066
1067 exc = err_ctx.exception
1068 self.assertIsInstance(exc, UniqueException)
1069 self.assertIsInstance(exc.__context__, UniqueRuntimeError)
1070 self.assertIsNone(exc.__context__.__context__)
1071 self.assertIsNone(exc.__context__.__cause__)
1072 self.assertIs(exc.__cause__, exc.__context__)
1073
1074
class TestExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared ExitStack tests against contextlib.ExitStack."""
    # (function name, source line) pairs that ExitStack.__exit__ adds to the
    # traceback of an exception raised by a callback. These must match the
    # installed contextlib source text.
    callback_error_internal_frames = [
        ('__exit__', 'raise exc_details[1]'),
        ('__exit__', 'if cb(*exc_details):'),
    ]
    exit_stack = ExitStack
1081
1082
1083 class ESC[4;38;5;81mTestRedirectStream:
1084
1085 redirect_stream = None
1086 orig_stream = None
1087
1088 @support.requires_docstrings
1089 def test_instance_docs(self):
1090 # Issue 19330: ensure context manager instances have good docstrings
1091 cm_docstring = self.redirect_stream.__doc__
1092 obj = self.redirect_stream(None)
1093 self.assertEqual(obj.__doc__, cm_docstring)
1094
1095 def test_no_redirect_in_init(self):
1096 orig_stdout = getattr(sys, self.orig_stream)
1097 self.redirect_stream(None)
1098 self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1099
1100 def test_redirect_to_string_io(self):
1101 f = io.StringIO()
1102 msg = "Consider an API like help(), which prints directly to stdout"
1103 orig_stdout = getattr(sys, self.orig_stream)
1104 with self.redirect_stream(f):
1105 print(msg, file=getattr(sys, self.orig_stream))
1106 self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1107 s = f.getvalue().strip()
1108 self.assertEqual(s, msg)
1109
1110 def test_enter_result_is_target(self):
1111 f = io.StringIO()
1112 with self.redirect_stream(f) as enter_result:
1113 self.assertIs(enter_result, f)
1114
1115 def test_cm_is_reusable(self):
1116 f = io.StringIO()
1117 write_to_f = self.redirect_stream(f)
1118 orig_stdout = getattr(sys, self.orig_stream)
1119 with write_to_f:
1120 print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1121 with write_to_f:
1122 print("World!", file=getattr(sys, self.orig_stream))
1123 self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1124 s = f.getvalue()
1125 self.assertEqual(s, "Hello World!\n")
1126
1127 def test_cm_is_reentrant(self):
1128 f = io.StringIO()
1129 write_to_f = self.redirect_stream(f)
1130 orig_stdout = getattr(sys, self.orig_stream)
1131 with write_to_f:
1132 print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1133 with write_to_f:
1134 print("World!", file=getattr(sys, self.orig_stream))
1135 self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1136 s = f.getvalue()
1137 self.assertEqual(s, "Hello World!\n")
1138
1139
class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
    """Run the shared redirect-stream tests against redirect_stdout."""

    redirect_stream = redirect_stdout
    # Name of the sys attribute that redirect_stdout replaces.
    orig_stream = "stdout"
1144
1145
class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
    """Run the shared redirect-stream tests against redirect_stderr."""

    redirect_stream = redirect_stderr
    # Name of the sys attribute that redirect_stderr replaces.
    orig_stream = "stderr"
1150
1151
1152 class ESC[4;38;5;81mTestSuppress(ESC[4;38;5;149mExceptionIsLikeMixin, ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
1153
1154 @support.requires_docstrings
1155 def test_instance_docs(self):
1156 # Issue 19330: ensure context manager instances have good docstrings
1157 cm_docstring = suppress.__doc__
1158 obj = suppress()
1159 self.assertEqual(obj.__doc__, cm_docstring)
1160
1161 def test_no_result_from_enter(self):
1162 with suppress(ValueError) as enter_result:
1163 self.assertIsNone(enter_result)
1164
1165 def test_no_exception(self):
1166 with suppress(ValueError):
1167 self.assertEqual(pow(2, 5), 32)
1168
1169 def test_exact_exception(self):
1170 with suppress(TypeError):
1171 len(5)
1172
1173 def test_exception_hierarchy(self):
1174 with suppress(LookupError):
1175 'Hello'[50]
1176
1177 def test_other_exception(self):
1178 with self.assertRaises(ZeroDivisionError):
1179 with suppress(TypeError):
1180 1/0
1181
1182 def test_no_args(self):
1183 with self.assertRaises(ZeroDivisionError):
1184 with suppress():
1185 1/0
1186
1187 def test_multiple_exception_args(self):
1188 with suppress(ZeroDivisionError, TypeError):
1189 1/0
1190 with suppress(ZeroDivisionError, TypeError):
1191 len(5)
1192
1193 def test_cm_is_reentrant(self):
1194 ignore_exceptions = suppress(Exception)
1195 with ignore_exceptions:
1196 pass
1197 with ignore_exceptions:
1198 len(5)
1199 with ignore_exceptions:
1200 with ignore_exceptions: # Check nested usage
1201 len(5)
1202 outer_continued = True
1203 1/0
1204 self.assertTrue(outer_continued)
1205
1206 def test_exception_groups(self):
1207 eg_ve = lambda: ExceptionGroup(
1208 "EG with ValueErrors only",
1209 [ValueError("ve1"), ValueError("ve2"), ValueError("ve3")],
1210 )
1211 eg_all = lambda: ExceptionGroup(
1212 "EG with many types of exceptions",
1213 [ValueError("ve1"), KeyError("ke1"), ValueError("ve2"), KeyError("ke2")],
1214 )
1215 with suppress(ValueError):
1216 raise eg_ve()
1217 with suppress(ValueError, KeyError):
1218 raise eg_all()
1219 with self.assertRaises(ExceptionGroup) as eg1:
1220 with suppress(ValueError):
1221 raise eg_all()
1222 self.assertExceptionIsLike(
1223 eg1.exception,
1224 ExceptionGroup(
1225 "EG with many types of exceptions",
1226 [KeyError("ke1"), KeyError("ke2")],
1227 ),
1228 )
1229
1230
1231 class ESC[4;38;5;81mTestChdir(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
1232 def make_relative_path(self, *parts):
1233 return os.path.join(
1234 os.path.dirname(os.path.realpath(__file__)),
1235 *parts,
1236 )
1237
1238 def test_simple(self):
1239 old_cwd = os.getcwd()
1240 target = self.make_relative_path('data')
1241 self.assertNotEqual(old_cwd, target)
1242
1243 with chdir(target):
1244 self.assertEqual(os.getcwd(), target)
1245 self.assertEqual(os.getcwd(), old_cwd)
1246
1247 def test_reentrant(self):
1248 old_cwd = os.getcwd()
1249 target1 = self.make_relative_path('data')
1250 target2 = self.make_relative_path('ziptestdata')
1251 self.assertNotIn(old_cwd, (target1, target2))
1252 chdir1, chdir2 = chdir(target1), chdir(target2)
1253
1254 with chdir1:
1255 self.assertEqual(os.getcwd(), target1)
1256 with chdir2:
1257 self.assertEqual(os.getcwd(), target2)
1258 with chdir1:
1259 self.assertEqual(os.getcwd(), target1)
1260 self.assertEqual(os.getcwd(), target2)
1261 self.assertEqual(os.getcwd(), target1)
1262 self.assertEqual(os.getcwd(), old_cwd)
1263
1264 def test_exception(self):
1265 old_cwd = os.getcwd()
1266 target = self.make_relative_path('data')
1267 self.assertNotEqual(old_cwd, target)
1268
1269 try:
1270 with chdir(target):
1271 self.assertEqual(os.getcwd(), target)
1272 raise RuntimeError("boom")
1273 except RuntimeError as re:
1274 self.assertEqual(str(re), "boom")
1275 self.assertEqual(os.getcwd(), old_cwd)
1276
1277
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()