|
1 """Unittests for heapq.""" |
|
2 |
|
3 from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest |
|
4 import random |
|
5 import unittest |
|
6 from test import test_support |
|
7 import sys |
|
8 |
|
9 |
|
def heapiter(heap):
    """Yield the elements of *heap* smallest-first, consuming the heap.

    Each step pops the current minimum with heappop(); iteration ends once
    heappop() raises IndexError.
    """
    while True:
        try:
            smallest = heappop(heap)
        except IndexError:
            return
        yield smallest
|
17 |
|
class TestHeap(unittest.TestCase):
    """Functional tests for heappush, heappop, heapify, heapreplace,
    nlargest and nsmallest, driven by randomly generated data."""

    def test_push_pop(self):
        """Pushing then popping random values must return them sorted."""
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
        heap = []
        data = []
        self.check_invariant(heap)
        for i in range(256):
            item = random.random()
            data.append(item)
            heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        self.assertEqual(sorted(data), results)
        # 2) Check that the invariant holds for a sorted array
        self.check_invariant(results)

        self.assertRaises(TypeError, heappush, [])
        # A non-list heap may be rejected with either TypeError or
        # AttributeError.  Accepting both in each assertion (rather than
        # wrapping the pair in try/except AttributeError) guarantees that
        # the heappop() check still runs when heappush() raised
        # AttributeError.
        self.assertRaises((TypeError, AttributeError), heappush, None, None)
        self.assertRaises((TypeError, AttributeError), heappop, None)

    def check_invariant(self, heap):
        """Fail the test unless every element of *heap* is >= its parent."""
        for pos, item in enumerate(heap):
            if pos:  # pos 0 has no parent
                parentpos = (pos - 1) >> 1
                self.assertTrue(
                    heap[parentpos] <= item,
                    "heap invariant violated at index %d: parent %r > child %r"
                    % (pos, heap[parentpos], item))

    def test_heapify(self):
        """heapify() must establish the invariant for sizes 0 through 29."""
        for size in range(30):
            heap = [random.random() for dummy in range(size)]
            heapify(heap)
            self.check_invariant(heap)

        self.assertRaises(TypeError, heapify, None)

    def test_naive_nbest(self):
        """Keep the 10 largest items by pushing each and popping the overflow."""
        data = [random.randrange(2000) for i in range(1000)]
        heap = []
        for item in data:
            heappush(heap, item)
            if len(heap) > 10:
                heappop(heap)
        heap.sort()
        self.assertEqual(heap, sorted(data)[-10:])

    def test_nbest(self):
        """Keep the 10 largest items using heapreplace() on a 10-item heap."""
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
        # enough <wink>) than sorting all of data.  However, if we had a max
        # heap instead of a min heap, it could go faster still via
        # heapify'ing all of data (linear time), then doing 10 heappops
        # (10 log-time steps).
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        heapify(heap)
        for item in data[10:]:
            if item > heap[0]:  # this gets rarer the longer we run
                heapreplace(heap, item)
        self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])

        self.assertRaises(TypeError, heapreplace, None)
        self.assertRaises(TypeError, heapreplace, None, None)
        self.assertRaises(IndexError, heapreplace, [], None)

    def test_heapsort(self):
        """Heapsort random data 100 times and compare against sorted()."""
        # Exercise everything with repeated heapsort checks
        for trial in range(100):
            size = random.randrange(50)
            data = [random.randrange(25) for i in range(size)]
            if trial & 1:  # Half of the time, use heapify
                heap = data[:]
                heapify(heap)
            else:  # The rest of the time, use heappush
                heap = []
                for item in data:
                    heappush(heap, item)
            heap_sorted = [heappop(heap) for i in range(size)]
            self.assertEqual(heap_sorted, sorted(data))

    def test_nsmallest(self):
        """nsmallest(n, ...) must equal the first n items of sorted(...)."""
        data = [(random.randrange(2000), i) for i in range(1000)]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            # The key-less form does not depend on the key function, so it
            # is checked once per n rather than once per key.
            self.assertEqual(nsmallest(n, data), sorted(data)[:n])
            for f in (None, lambda x: x[0] * 547 % 2000):
                self.assertEqual(nsmallest(n, data, key=f),
                                 sorted(data, key=f)[:n])

    def test_nlargest(self):
        """nlargest(n, ...) must equal the first n items of a reverse sort."""
        data = [(random.randrange(2000), i) for i in range(1000)]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            # The key-less form does not depend on the key function, so it
            # is checked once per n rather than once per key.
            self.assertEqual(nlargest(n, data),
                             sorted(data, reverse=True)[:n])
            for f in (None, lambda x: x[0] * 547 % 2000):
                self.assertEqual(nlargest(n, data, key=f),
                                 sorted(data, key=f, reverse=True)[:n])
|
121 |
|
122 |
|
123 #============================================================================== |
|
124 |
|
class LenOnly:
    """Fake sequence: reports a length of 10 but provides no __getitem__."""

    def __len__(self):
        return 10
|
129 |
|
class GetOnly:
    """Fake sequence: every index yields 10, and no __len__ is provided."""

    def __getitem__(self, ndx):
        return 10
|
134 |
|
class CmpErr:
    """Element whose __cmp__ hook raises ZeroDivisionError for any operand."""

    def __cmp__(self, other):
        raise ZeroDivisionError
|
139 |
|
def R(seqn):
    """Plain generator that re-yields every item of *seqn* in order."""
    for element in seqn:
        yield element
|
144 |
|
class G:
    """Wrapper exposing *seqn* through __getitem__ only (no __iter__)."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        # Delegate straight to the wrapped sequence; whatever it raises
        # for a bad index propagates to the caller.
        wrapped = self.seqn
        return wrapped[i]
|
151 |
|
class I:
    """Hand-written iterator over *seqn*: __iter__ returns self and next()
    walks the sequence by index."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        pos = self.i
        if not pos < len(self.seqn):
            raise StopIteration
        item = self.seqn[pos]
        # Advance only after the lookup has succeeded.
        self.i = pos + 1
        return item
|
164 |
|
class Ig:
    """Iterable over *seqn* whose __iter__ is a generator function."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        for element in self.seqn:
            yield element
|
173 |
|
class X:
    """Not iterable: has a next() method but neither __getitem__ nor
    __iter__."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        pos = self.i
        if not pos < len(self.seqn):
            raise StopIteration
        item = self.seqn[pos]
        # Advance only after the lookup has succeeded.
        self.i = pos + 1
        return item
|
184 |
|
class N:
    """Broken iterator: __iter__ returns self, but no next() is defined."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self
|
192 |
|
class E:
    """Iterator whose next() always fails with ZeroDivisionError, used to
    check that exceptions raised during iteration propagate."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        # Deliberate division by zero; the expression's value is never used.
        3 // 0
|
202 |
|
class S:
    """Iterator that is exhausted immediately; its argument is ignored."""

    def __init__(self, seqn):
        # The sequence is accepted only for signature parity with the
        # other helper classes; nothing is stored.
        pass

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
|
211 |
|
212 from itertools import chain, imap |
|
def L(seqn):
    """Return *seqn* wrapped in several stacked iterator layers.

    The layers, innermost first, are G, Ig, R, an identity imap() and
    chain().
    """
    indexed = G(seqn)
    regenerated = R(Ig(indexed))
    mapped = imap(lambda x: x, regenerated)
    return chain(mapped)
|
216 |
|
class TestErrorHandling(unittest.TestCase):
    """Checks that the heapq functions reject unsuitable arguments and let
    errors raised by elements or iterators propagate."""

    def test_non_sequence(self):
        """An int in place of the heap/iterable must raise TypeError."""
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, 10)
        for f in (heappush, heapreplace, nlargest, nsmallest):
            self.assertRaises(TypeError, f, 10, 10)

    def test_len_only(self):
        """An object with __len__ but no __getitem__ must raise TypeError."""
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, LenOnly())
        for f in (heappush, heapreplace):
            self.assertRaises(TypeError, f, LenOnly(), 10)
        for f in (nlargest, nsmallest):
            self.assertRaises(TypeError, f, 2, LenOnly())

    def test_get_only(self):
        """An object with __getitem__ but no __len__ must raise TypeError."""
        # nlargest/nsmallest are deliberately not exercised here: they accept
        # objects that only define __getitem__ (see test_iterable_args), and
        # GetOnly.__getitem__ never raises IndexError, so iterating over a
        # GetOnly instance would never terminate.
        for f in (heapify, heappop):
            self.assertRaises(TypeError, f, GetOnly())
        for f in (heappush, heapreplace):
            self.assertRaises(TypeError, f, GetOnly(), 10)

    def test_cmp_err(self):
        """An exception raised while comparing elements must propagate."""
        # This method used to be a second "test_get_only", which silently
        # replaced the GetOnly test above in the class namespace.
        seq = [CmpErr(), CmpErr(), CmpErr()]
        for f in (heapify, heappop):
            self.assertRaises(ZeroDivisionError, f, seq)
        for f in (heappush, heapreplace):
            self.assertRaises(ZeroDivisionError, f, seq, 10)
        for f in (nlargest, nsmallest):
            self.assertRaises(ZeroDivisionError, f, 2, seq)

    def test_arg_parsing(self):
        """Calling any function with a single int must raise TypeError."""
        for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest):
            self.assertRaises(TypeError, f, 10)

    def test_iterable_args(self):
        """nlargest/nsmallest must cope with every flavour of iterable."""
        for f in (nlargest, nsmallest):
            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
                # Well-behaved wrappers must give the same answer as the
                # bare sequence.
                for g in (G, I, Ig, L, R):
                    self.assertEqual(f(2, g(s)), f(2,s))
                # An immediately exhausted iterator yields no results.
                self.assertEqual(f(2, S(s)), [])
                # Objects that are not valid iterables/iterators are rejected.
                self.assertRaises(TypeError, f, 2, X(s))
                self.assertRaises(TypeError, f, 2, N(s))
                # Errors raised by the iterator itself must propagate.
                self.assertRaises(ZeroDivisionError, f, 2, E(s))
|
263 |
|
264 #============================================================================== |
|
265 |
|
266 |
|
def test_main(verbose=None):
    """Run the heapq test classes, optionally sampling total refcounts.

    TestHeap always runs; TestErrorHandling is added only when heapify is a
    builtin function.  When *verbose* is true and the interpreter provides
    sys.gettotalrefcount, the classes are run five more times and the total
    reference count recorded after each run is printed.
    """
    from types import BuiltinFunctionType

    suites = [TestHeap]
    if isinstance(heapify, BuiltinFunctionType):
        suites.append(TestErrorHandling)
    test_support.run_unittest(*suites)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    # NOTE(review): the list is presumably preallocated so that storing a
    # sample does not itself shift the totals between runs -- keep it so.
    counts = [None] * 5
    for i in xrange(len(counts)):
        test_support.run_unittest(*suites)
        gc.collect()
        counts[i] = sys.gettotalrefcount()
    print(counts)
|
284 |
|
# Script entry point: run verbosely so that test_main() also attempts its
# reference-count pass.
if __name__ == "__main__":
    test_main(verbose=True)