import builtins
import keyword
import os
import random
import shutil
import string
import sys
import tempfile
import unittest

from importlib.util import cache_from_source
from test.support.os_helper import create_empty_file

class TestImport(unittest.TestCase):
    """Check how the import system cleans up after a failing submodule import.

    Each test gets a fresh temporary directory appended to ``sys.path``
    holding a package (``self.package_name``) with an empty ``__init__.py``;
    ``rewrite_file`` replaces that package's ``foo`` submodule with
    arbitrary source.
    """

    def __init__(self, *args, **kw):
        # Pick a package name no already-imported module uses, so the test
        # neither collides with nor clobbers existing sys.modules entries.
        self.package_name = 'PACKAGE_'
        while self.package_name in sys.modules:
            self.package_name += random.choice(string.ascii_letters)
        self.module_name = self.package_name + '.foo'
        super().__init__(*args, **kw)

    def remove_modules(self):
        """Drop the test package and its submodule from ``sys.modules``."""
        for module_name in (self.package_name, self.module_name):
            sys.modules.pop(module_name, None)

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        sys.path.append(self.test_dir)
        self.package_dir = os.path.join(self.test_dir, self.package_name)
        os.mkdir(self.package_dir)
        # An empty __init__.py makes package_dir a regular package.
        init_path = os.path.join(self.package_dir, '__init__.py')
        with open(init_path, 'w', encoding='utf-8'):
            pass
        self.module_path = os.path.join(self.package_dir, 'foo.py')

    def tearDown(self):
        # Restore interpreter state even if deleting the directory fails, so
        # later tests never see a stale sys.path entry or cached modules.
        try:
            shutil.rmtree(self.test_dir)
        finally:
            self.remove_modules()
            self.assertIn(self.test_dir, sys.path)
            sys.path.remove(self.test_dir)

    def rewrite_file(self, contents):
        """Replace ``foo.py`` with *contents*, discarding cached bytecode.

        Removing the ``.pyc`` forces the next import to compile the new
        source instead of risking reuse of bytecode from an older version.
        """
        compiled_path = cache_from_source(self.module_path)
        try:
            os.remove(compiled_path)
        except FileNotFoundError:
            pass
        with open(self.module_path, 'w', encoding='utf-8') as f:
            f.write(contents)

    @staticmethod
    def _unbound_name():
        """Return an identifier that is neither a builtin nor a keyword.

        Uses the ``builtins`` module rather than ``__builtins__``: outside
        ``__main__`` the latter is a plain dict in CPython, and ``dir()`` of
        it lists dict methods instead of builtin names.
        """
        name = 'a'
        while hasattr(builtins, name) or keyword.iskeyword(name):
            name += random.choice(string.ascii_letters)
        return name

    def _assert_import_fails(self, exc_type):
        """Assert importing the submodule raises *exc_type* and leaves no trace.

        A failed import must not leave the submodule in ``sys.modules`` nor
        bind it as an attribute of the (already imported) parent package.
        """
        with self.assertRaises(exc_type):
            __import__(self.module_name)
        self.assertNotIn(self.module_name, sys.modules)
        self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))

    def test_package_import__semantics(self):
        # A SyntaxError while compiling the submodule aborts the import.
        self.rewrite_file('for')
        self._assert_import_fails(SyntaxError)

        # A NameError while executing the submodule aborts it as well.
        var = self._unbound_name()
        self.rewrite_file(var)
        self._assert_import_fails(NameError)

        # Once the source is fixed, a fresh import succeeds and sees the
        # new binding.
        self.rewrite_file(f'{var} = 1')
        module = __import__(self.module_name).foo
        self.assertEqual(getattr(module, var), 1)
78
if __name__ == "__main__":
    # Executed only when this file is run directly: hand control to the
    # unittest command-line runner for the test cases in this module.
    unittest.main()