|
import os
import random
import shutil
import string
import sys
import tempfile
import unittest

from test.test_support import run_unittest
|
4 |
|
class TestImport(unittest.TestCase):
    """Check the state left behind when a package submodule fails to import.

    Each test builds a throw-away package (``self.package_name``) holding an
    empty ``__init__.py`` and a module ``foo.py`` inside a fresh temporary
    directory that is appended to ``sys.path`` for the duration of the test.
    """

    def __init__(self, *args, **kw):
        # Pick a top-level package name that is not already in sys.modules so
        # the test never shadows or clobbers a real module.  ascii_letters
        # (unlike the locale-dependent string.letters) keeps the name a valid
        # ASCII identifier.
        self.package_name = 'PACKAGE_'
        while self.package_name in sys.modules:
            self.package_name += random.choice(string.ascii_letters)
        self.module_name = self.package_name + '.foo'
        unittest.TestCase.__init__(self, *args, **kw)

    def remove_modules(self):
        """Drop the test package and its submodule from sys.modules, if present."""
        for module_name in (self.package_name, self.module_name):
            sys.modules.pop(module_name, None)

    def setUp(self):
        """Create the temporary package directory and append its parent to sys.path."""
        self.test_dir = tempfile.mkdtemp()
        sys.path.append(self.test_dir)
        self.package_dir = os.path.join(self.test_dir, self.package_name)
        os.mkdir(self.package_dir)
        # An empty __init__.py turns the directory into a package.  Close the
        # handle explicitly instead of relying on garbage collection.
        init_path = os.path.join(self.package_dir, '__init__' + os.extsep + 'py')
        open(init_path, 'w').close()
        self.module_path = os.path.join(self.package_dir,
                                        'foo' + os.extsep + 'py')

    def tearDown(self):
        """Delete the temporary tree and undo the sys.path / sys.modules changes."""
        # rmtree also removes subdirectories (e.g. a __pycache__ directory
        # the import system may create), which os.remove cannot delete.
        shutil.rmtree(self.test_dir)
        # Forget the modules first so they are dropped even if the sanity
        # check below fails.
        self.remove_modules()
        # The test itself must not have removed our sys.path entry.
        self.assertNotEqual(sys.path.count(self.test_dir), 0)
        sys.path.remove(self.test_dir)

    def rewrite_file(self, contents):
        """Overwrite foo.py with *contents*.

        Any compiled sibling (foo.pyc / foo.pyo) is deleted first so that a
        stale compiled file cannot be imported instead of the new source.
        """
        for extension in 'co':
            compiled_path = self.module_path + extension
            if os.path.exists(compiled_path):
                os.remove(compiled_path)
        f = open(self.module_path, 'w')
        try:
            f.write(contents)
        finally:
            f.close()

    def test_package_import__semantics(self):
        """A submodule that fails to import must not be left half-initialised."""
        # ...try loading the module when there's a SyntaxError
        self.rewrite_file('for')
        self.assertRaises(SyntaxError, __import__, self.module_name)
        # Neither sys.modules nor the parent package may keep a reference to
        # the broken submodule.
        self.assertFalse(self.module_name in sys.modules)
        self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))

        # ...make up a variable name that isn't bound in the builtins module.
        # Query the builtins module itself: __builtins__ may be the builtins
        # *dict* (outside __main__ in CPython), in which case dir() would
        # list dict methods rather than builtin names.
        try:
            import __builtin__ as builtins  # Python 2
        except ImportError:
            import builtins  # Python 3
        var = 'a'
        while hasattr(builtins, var):
            var += random.choice(string.ascii_letters)

        # ...make a module that just contains that name, so executing it
        # raises NameError
        self.rewrite_file(var)
        self.assertRaises(NameError, __import__, self.module_name)

        # ...now change the module so that the NameError doesn't happen
        self.rewrite_file('%s = 1' % var)
        module = __import__(self.module_name).foo
        self.assertEqual(getattr(module, var), 1)
|
75 |
|
76 |
|
def test_main():
    """Run the TestImport case through the regression-test runner."""
    cases = (TestImport,)
    run_unittest(*cases)
|
79 |
|
80 |
|
# Script entry point: run the suite only when executed directly, not on import.
if __name__ == "__main__":
    test_main()