|
1 """Unit tests for collections.defaultdict.""" |
|
2 |
|
3 import os |
|
4 import copy |
|
5 import tempfile |
|
6 import unittest |
|
7 from test import test_support |
|
8 |
|
9 from collections import defaultdict |
|
10 |
|
def foobar():
    """Return the built-in ``list`` type itself (not a new list).

    Serves as a module-level default_factory in the copy tests, which check
    that the factory object is preserved by copy.copy() and copy.deepcopy().
    """
    factory = list
    return factory
|
13 |
|
14 class TestDefaultDict(unittest.TestCase): |
|
15 |
|
16 def test_basic(self): |
|
17 d1 = defaultdict() |
|
18 self.assertEqual(d1.default_factory, None) |
|
19 d1.default_factory = list |
|
20 d1[12].append(42) |
|
21 self.assertEqual(d1, {12: [42]}) |
|
22 d1[12].append(24) |
|
23 self.assertEqual(d1, {12: [42, 24]}) |
|
24 d1[13] |
|
25 d1[14] |
|
26 self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) |
|
27 self.assert_(d1[12] is not d1[13] is not d1[14]) |
|
28 d2 = defaultdict(list, foo=1, bar=2) |
|
29 self.assertEqual(d2.default_factory, list) |
|
30 self.assertEqual(d2, {"foo": 1, "bar": 2}) |
|
31 self.assertEqual(d2["foo"], 1) |
|
32 self.assertEqual(d2["bar"], 2) |
|
33 self.assertEqual(d2[42], []) |
|
34 self.assert_("foo" in d2) |
|
35 self.assert_("foo" in d2.keys()) |
|
36 self.assert_("bar" in d2) |
|
37 self.assert_("bar" in d2.keys()) |
|
38 self.assert_(42 in d2) |
|
39 self.assert_(42 in d2.keys()) |
|
40 self.assert_(12 not in d2) |
|
41 self.assert_(12 not in d2.keys()) |
|
42 d2.default_factory = None |
|
43 self.assertEqual(d2.default_factory, None) |
|
44 try: |
|
45 d2[15] |
|
46 except KeyError, err: |
|
47 self.assertEqual(err.args, (15,)) |
|
48 else: |
|
49 self.fail("d2[15] didn't raise KeyError") |
|
50 self.assertRaises(TypeError, defaultdict, 1) |
|
51 |
|
52 def test_missing(self): |
|
53 d1 = defaultdict() |
|
54 self.assertRaises(KeyError, d1.__missing__, 42) |
|
55 d1.default_factory = list |
|
56 self.assertEqual(d1.__missing__(42), []) |
|
57 |
|
58 def test_repr(self): |
|
59 d1 = defaultdict() |
|
60 self.assertEqual(d1.default_factory, None) |
|
61 self.assertEqual(repr(d1), "defaultdict(None, {})") |
|
62 d1[11] = 41 |
|
63 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") |
|
64 d2 = defaultdict(int) |
|
65 self.assertEqual(d2.default_factory, int) |
|
66 d2[12] = 42 |
|
67 self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})") |
|
68 def foo(): return 43 |
|
69 d3 = defaultdict(foo) |
|
70 self.assert_(d3.default_factory is foo) |
|
71 d3[13] |
|
72 self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) |
|
73 |
|
74 def test_print(self): |
|
75 d1 = defaultdict() |
|
76 def foo(): return 42 |
|
77 d2 = defaultdict(foo, {1: 2}) |
|
78 # NOTE: We can't use tempfile.[Named]TemporaryFile since this |
|
79 # code must exercise the tp_print C code, which only gets |
|
80 # invoked for *real* files. |
|
81 tfn = tempfile.mktemp() |
|
82 try: |
|
83 f = open(tfn, "w+") |
|
84 try: |
|
85 print >>f, d1 |
|
86 print >>f, d2 |
|
87 f.seek(0) |
|
88 self.assertEqual(f.readline(), repr(d1) + "\n") |
|
89 self.assertEqual(f.readline(), repr(d2) + "\n") |
|
90 finally: |
|
91 f.close() |
|
92 finally: |
|
93 os.remove(tfn) |
|
94 |
|
95 def test_copy(self): |
|
96 d1 = defaultdict() |
|
97 d2 = d1.copy() |
|
98 self.assertEqual(type(d2), defaultdict) |
|
99 self.assertEqual(d2.default_factory, None) |
|
100 self.assertEqual(d2, {}) |
|
101 d1.default_factory = list |
|
102 d3 = d1.copy() |
|
103 self.assertEqual(type(d3), defaultdict) |
|
104 self.assertEqual(d3.default_factory, list) |
|
105 self.assertEqual(d3, {}) |
|
106 d1[42] |
|
107 d4 = d1.copy() |
|
108 self.assertEqual(type(d4), defaultdict) |
|
109 self.assertEqual(d4.default_factory, list) |
|
110 self.assertEqual(d4, {42: []}) |
|
111 d4[12] |
|
112 self.assertEqual(d4, {42: [], 12: []}) |
|
113 |
|
114 def test_shallow_copy(self): |
|
115 d1 = defaultdict(foobar, {1: 1}) |
|
116 d2 = copy.copy(d1) |
|
117 self.assertEqual(d2.default_factory, foobar) |
|
118 self.assertEqual(d2, d1) |
|
119 d1.default_factory = list |
|
120 d2 = copy.copy(d1) |
|
121 self.assertEqual(d2.default_factory, list) |
|
122 self.assertEqual(d2, d1) |
|
123 |
|
124 def test_deep_copy(self): |
|
125 d1 = defaultdict(foobar, {1: [1]}) |
|
126 d2 = copy.deepcopy(d1) |
|
127 self.assertEqual(d2.default_factory, foobar) |
|
128 self.assertEqual(d2, d1) |
|
129 self.assert_(d1[1] is not d2[1]) |
|
130 d1.default_factory = list |
|
131 d2 = copy.deepcopy(d1) |
|
132 self.assertEqual(d2.default_factory, list) |
|
133 self.assertEqual(d2, d1) |
|
134 |
|
135 def test_keyerror_without_factory(self): |
|
136 d1 = defaultdict() |
|
137 try: |
|
138 d1[(1,)] |
|
139 except KeyError, err: |
|
140 self.assertEqual(err.args[0], (1,)) |
|
141 else: |
|
142 self.fail("expected KeyError") |
|
143 |
|
144 def test_recursive_repr(self): |
|
145 # Issue2045: stack overflow when default_factory is a bound method |
|
146 class sub(defaultdict): |
|
147 def __init__(self): |
|
148 self.default_factory = self._factory |
|
149 def _factory(self): |
|
150 return [] |
|
151 d = sub() |
|
152 self.assert_(repr(d).startswith( |
|
153 "defaultdict(<bound method sub._factory of defaultdict(...")) |
|
154 |
|
155 # NOTE: printing a subclass of a builtin type does not call its |
|
156 # tp_print slot. So this part is essentially the same test as above. |
|
157 tfn = tempfile.mktemp() |
|
158 try: |
|
159 f = open(tfn, "w+") |
|
160 try: |
|
161 print >>f, d |
|
162 finally: |
|
163 f.close() |
|
164 finally: |
|
165 os.remove(tfn) |
|
166 |
|
167 |
|
def test_main():
    """Run every test case defined in TestDefaultDict via test_support."""
    test_classes = (TestDefaultDict,)
    test_support.run_unittest(*test_classes)
|
170 |
|
if __name__ == "__main__":
    # Executed directly as a script (not imported): run the suite.
    test_main()