symbian-qemu-0.9.1-12/python-2.6.1/Lib/test/test_heapq.py
changeset 1 2fb8b9db1c86
equal deleted inserted replaced
0:ffa851df0825 1:2fb8b9db1c86
       
     1 """Unittests for heapq."""
       
     2 
       
     3 import random
       
     4 import unittest
       
     5 from test import test_support
       
     6 import sys
       
     7 
       
# We do a bit of trickery here to be able to test both the C implementation
# and the Python implementation of the module.

# Make it impossible to import the C implementation anymore.
# (A placeholder that is not a real module is cached under '_heapq', so
# heapq's attempt to import its C accelerator fails.)
sys.modules['_heapq'] = 0
# We must also handle the case that heapq was imported before.
# (Dropping the cached entry forces the next import to execute heapq afresh.)
if 'heapq' in sys.modules:
    del sys.modules['heapq']

# Now we can import the module and get the pure Python implementation.
import heapq as py_heapq

# Restore everything to normal.
# NOTE(review): this removes the entries rather than reinstating any module
# objects cached before this file ran; the next import creates fresh ones.
del sys.modules['_heapq']
del sys.modules['heapq']

# This is now the module with the C implementation.
import heapq as c_heapq
       
    26 
       
    27 
       
    28 class TestHeap(unittest.TestCase):
       
    29     module = None
       
    30 
       
    31     def test_push_pop(self):
       
    32         # 1) Push 256 random numbers and pop them off, verifying all's OK.
       
    33         heap = []
       
    34         data = []
       
    35         self.check_invariant(heap)
       
    36         for i in range(256):
       
    37             item = random.random()
       
    38             data.append(item)
       
    39             self.module.heappush(heap, item)
       
    40             self.check_invariant(heap)
       
    41         results = []
       
    42         while heap:
       
    43             item = self.module.heappop(heap)
       
    44             self.check_invariant(heap)
       
    45             results.append(item)
       
    46         data_sorted = data[:]
       
    47         data_sorted.sort()
       
    48         self.assertEqual(data_sorted, results)
       
    49         # 2) Check that the invariant holds for a sorted array
       
    50         self.check_invariant(results)
       
    51 
       
    52         self.assertRaises(TypeError, self.module.heappush, [])
       
    53         try:
       
    54             self.assertRaises(TypeError, self.module.heappush, None, None)
       
    55             self.assertRaises(TypeError, self.module.heappop, None)
       
    56         except AttributeError:
       
    57             pass
       
    58 
       
    59     def check_invariant(self, heap):
       
    60         # Check the heap invariant.
       
    61         for pos, item in enumerate(heap):
       
    62             if pos: # pos 0 has no parent
       
    63                 parentpos = (pos-1) >> 1
       
    64                 self.assert_(heap[parentpos] <= item)
       
    65 
       
    66     def test_heapify(self):
       
    67         for size in range(30):
       
    68             heap = [random.random() for dummy in range(size)]
       
    69             self.module.heapify(heap)
       
    70             self.check_invariant(heap)
       
    71 
       
    72         self.assertRaises(TypeError, self.module.heapify, None)
       
    73 
       
    74     def test_naive_nbest(self):
       
    75         data = [random.randrange(2000) for i in range(1000)]
       
    76         heap = []
       
    77         for item in data:
       
    78             self.module.heappush(heap, item)
       
    79             if len(heap) > 10:
       
    80                 self.module.heappop(heap)
       
    81         heap.sort()
       
    82         self.assertEqual(heap, sorted(data)[-10:])
       
    83 
       
    84     def heapiter(self, heap):
       
    85         # An iterator returning a heap's elements, smallest-first.
       
    86         try:
       
    87             while 1:
       
    88                 yield self.module.heappop(heap)
       
    89         except IndexError:
       
    90             pass
       
    91 
       
    92     def test_nbest(self):
       
    93         # Less-naive "N-best" algorithm, much faster (if len(data) is big
       
    94         # enough <wink>) than sorting all of data.  However, if we had a max
       
    95         # heap instead of a min heap, it could go faster still via
       
    96         # heapify'ing all of data (linear time), then doing 10 heappops
       
    97         # (10 log-time steps).
       
    98         data = [random.randrange(2000) for i in range(1000)]
       
    99         heap = data[:10]
       
   100         self.module.heapify(heap)
       
   101         for item in data[10:]:
       
   102             if item > heap[0]:  # this gets rarer the longer we run
       
   103                 self.module.heapreplace(heap, item)
       
   104         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
       
   105 
       
   106         self.assertRaises(TypeError, self.module.heapreplace, None)
       
   107         self.assertRaises(TypeError, self.module.heapreplace, None, None)
       
   108         self.assertRaises(IndexError, self.module.heapreplace, [], None)
       
   109 
       
   110     def test_nbest_with_pushpop(self):
       
   111         data = [random.randrange(2000) for i in range(1000)]
       
   112         heap = data[:10]
       
   113         self.module.heapify(heap)
       
   114         for item in data[10:]:
       
   115             self.module.heappushpop(heap, item)
       
   116         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
       
   117         self.assertEqual(self.module.heappushpop([], 'x'), 'x')
       
   118 
       
   119     def test_heappushpop(self):
       
   120         h = []
       
   121         x = self.module.heappushpop(h, 10)
       
   122         self.assertEqual((h, x), ([], 10))
       
   123 
       
   124         h = [10]
       
   125         x = self.module.heappushpop(h, 10.0)
       
   126         self.assertEqual((h, x), ([10], 10.0))
       
   127         self.assertEqual(type(h[0]), int)
       
   128         self.assertEqual(type(x), float)
       
   129 
       
   130         h = [10];
       
   131         x = self.module.heappushpop(h, 9)
       
   132         self.assertEqual((h, x), ([10], 9))
       
   133 
       
   134         h = [10];
       
   135         x = self.module.heappushpop(h, 11)
       
   136         self.assertEqual((h, x), ([11], 10))
       
   137 
       
   138     def test_heapsort(self):
       
   139         # Exercise everything with repeated heapsort checks
       
   140         for trial in xrange(100):
       
   141             size = random.randrange(50)
       
   142             data = [random.randrange(25) for i in range(size)]
       
   143             if trial & 1:     # Half of the time, use heapify
       
   144                 heap = data[:]
       
   145                 self.module.heapify(heap)
       
   146             else:             # The rest of the time, use heappush
       
   147                 heap = []
       
   148                 for item in data:
       
   149                     self.module.heappush(heap, item)
       
   150             heap_sorted = [self.module.heappop(heap) for i in range(size)]
       
   151             self.assertEqual(heap_sorted, sorted(data))
       
   152 
       
   153     def test_merge(self):
       
   154         inputs = []
       
   155         for i in xrange(random.randrange(5)):
       
   156             row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
       
   157             inputs.append(row)
       
   158         self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
       
   159         self.assertEqual(list(self.module.merge()), [])
       
   160 
       
   161     def test_merge_stability(self):
       
   162         class Int(int):
       
   163             pass
       
   164         inputs = [[], [], [], []]
       
   165         for i in range(20000):
       
   166             stream = random.randrange(4)
       
   167             x = random.randrange(500)
       
   168             obj = Int(x)
       
   169             obj.pair = (x, stream)
       
   170             inputs[stream].append(obj)
       
   171         for stream in inputs:
       
   172             stream.sort()
       
   173         result = [i.pair for i in self.module.merge(*inputs)]
       
   174         self.assertEqual(result, sorted(result))
       
   175 
       
   176     def test_nsmallest(self):
       
   177         data = [(random.randrange(2000), i) for i in range(1000)]
       
   178         for f in (None, lambda x:  x[0] * 547 % 2000):
       
   179             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
       
   180                 self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
       
   181                 self.assertEqual(self.module.nsmallest(n, data, key=f),
       
   182                                  sorted(data, key=f)[:n])
       
   183 
       
   184     def test_nlargest(self):
       
   185         data = [(random.randrange(2000), i) for i in range(1000)]
       
   186         for f in (None, lambda x:  x[0] * 547 % 2000):
       
   187             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
       
   188                 self.assertEqual(self.module.nlargest(n, data),
       
   189                                  sorted(data, reverse=True)[:n])
       
   190                 self.assertEqual(self.module.nlargest(n, data, key=f),
       
   191                                  sorted(data, key=f, reverse=True)[:n])
       
   192 
       
   193 class TestHeapPython(TestHeap):
       
   194     module = py_heapq
       
   195 
       
   196 class TestHeapC(TestHeap):
       
   197     module = c_heapq
       
   198 
       
   199     def test_comparison_operator(self):
       
   200         # Issue 3501: Make sure heapq works with both __lt__ and __le__
       
   201         def hsort(data, comp):
       
   202             data = map(comp, data)
       
   203             self.module.heapify(data)
       
   204             return [self.module.heappop(data).x for i in range(len(data))]
       
   205         class LT:
       
   206             def __init__(self, x):
       
   207                 self.x = x
       
   208             def __lt__(self, other):
       
   209                 return self.x > other.x
       
   210         class LE:
       
   211             def __init__(self, x):
       
   212                 self.x = x
       
   213             def __le__(self, other):
       
   214                 return self.x >= other.x
       
   215         data = [random.random() for i in range(100)]
       
   216         target = sorted(data, reverse=True)
       
   217         self.assertEqual(hsort(data, LT), target)
       
   218         self.assertEqual(hsort(data, LE), target)
       
   219 
       
   220 
       
   221 #==============================================================================
       
   222 
       
   223 class LenOnly:
       
   224     "Dummy sequence class defining __len__ but not __getitem__."
       
   225     def __len__(self):
       
   226         return 10
       
   227 
       
   228 class GetOnly:
       
   229     "Dummy sequence class defining __getitem__ but not __len__."
       
   230     def __getitem__(self, ndx):
       
   231         return 10
       
   232 
       
   233 class CmpErr:
       
   234     "Dummy element that always raises an error during comparison"
       
   235     def __cmp__(self, other):
       
   236         raise ZeroDivisionError
       
   237 
       
   238 def R(seqn):
       
   239     'Regular generator'
       
   240     for i in seqn:
       
   241         yield i
       
   242 
       
   243 class G:
       
   244     'Sequence using __getitem__'
       
   245     def __init__(self, seqn):
       
   246         self.seqn = seqn
       
   247     def __getitem__(self, i):
       
   248         return self.seqn[i]
       
   249 
       
   250 class I:
       
   251     'Sequence using iterator protocol'
       
   252     def __init__(self, seqn):
       
   253         self.seqn = seqn
       
   254         self.i = 0
       
   255     def __iter__(self):
       
   256         return self
       
   257     def next(self):
       
   258         if self.i >= len(self.seqn): raise StopIteration
       
   259         v = self.seqn[self.i]
       
   260         self.i += 1
       
   261         return v
       
   262 
       
   263 class Ig:
       
   264     'Sequence using iterator protocol defined with a generator'
       
   265     def __init__(self, seqn):
       
   266         self.seqn = seqn
       
   267         self.i = 0
       
   268     def __iter__(self):
       
   269         for val in self.seqn:
       
   270             yield val
       
   271 
       
   272 class X:
       
   273     'Missing __getitem__ and __iter__'
       
   274     def __init__(self, seqn):
       
   275         self.seqn = seqn
       
   276         self.i = 0
       
   277     def next(self):
       
   278         if self.i >= len(self.seqn): raise StopIteration
       
   279         v = self.seqn[self.i]
       
   280         self.i += 1
       
   281         return v
       
   282 
       
   283 class N:
       
   284     'Iterator missing next()'
       
   285     def __init__(self, seqn):
       
   286         self.seqn = seqn
       
   287         self.i = 0
       
   288     def __iter__(self):
       
   289         return self
       
   290 
       
   291 class E:
       
   292     'Test propagation of exceptions'
       
   293     def __init__(self, seqn):
       
   294         self.seqn = seqn
       
   295         self.i = 0
       
   296     def __iter__(self):
       
   297         return self
       
   298     def next(self):
       
   299         3 // 0
       
   300 
       
   301 class S:
       
   302     'Test immediate stop'
       
   303     def __init__(self, seqn):
       
   304         pass
       
   305     def __iter__(self):
       
   306         return self
       
   307     def next(self):
       
   308         raise StopIteration
       
   309 
       
   310 from itertools import chain, imap
       
   311 def L(seqn):
       
   312     'Test multiple tiers of iterators'
       
   313     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
       
   314 
       
   315 class TestErrorHandling(unittest.TestCase):
       
   316     # only for C implementation
       
   317     module = c_heapq
       
   318 
       
   319     def test_non_sequence(self):
       
   320         for f in (self.module.heapify, self.module.heappop):
       
   321             self.assertRaises(TypeError, f, 10)
       
   322         for f in (self.module.heappush, self.module.heapreplace,
       
   323                   self.module.nlargest, self.module.nsmallest):
       
   324             self.assertRaises(TypeError, f, 10, 10)
       
   325 
       
   326     def test_len_only(self):
       
   327         for f in (self.module.heapify, self.module.heappop):
       
   328             self.assertRaises(TypeError, f, LenOnly())
       
   329         for f in (self.module.heappush, self.module.heapreplace):
       
   330             self.assertRaises(TypeError, f, LenOnly(), 10)
       
   331         for f in (self.module.nlargest, self.module.nsmallest):
       
   332             self.assertRaises(TypeError, f, 2, LenOnly())
       
   333 
       
   334     def test_get_only(self):
       
   335         for f in (self.module.heapify, self.module.heappop):
       
   336             self.assertRaises(TypeError, f, GetOnly())
       
   337         for f in (self.module.heappush, self.module.heapreplace):
       
   338             self.assertRaises(TypeError, f, GetOnly(), 10)
       
   339         for f in (self.module.nlargest, self.module.nsmallest):
       
   340             self.assertRaises(TypeError, f, 2, GetOnly())
       
   341 
       
   342     def test_get_only(self):
       
   343         seq = [CmpErr(), CmpErr(), CmpErr()]
       
   344         for f in (self.module.heapify, self.module.heappop):
       
   345             self.assertRaises(ZeroDivisionError, f, seq)
       
   346         for f in (self.module.heappush, self.module.heapreplace):
       
   347             self.assertRaises(ZeroDivisionError, f, seq, 10)
       
   348         for f in (self.module.nlargest, self.module.nsmallest):
       
   349             self.assertRaises(ZeroDivisionError, f, 2, seq)
       
   350 
       
   351     def test_arg_parsing(self):
       
   352         for f in (self.module.heapify, self.module.heappop,
       
   353                   self.module.heappush, self.module.heapreplace,
       
   354                   self.module.nlargest, self.module.nsmallest):
       
   355             self.assertRaises(TypeError, f, 10)
       
   356 
       
   357     def test_iterable_args(self):
       
   358         for f in (self.module.nlargest, self.module.nsmallest):
       
   359             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
       
   360                 for g in (G, I, Ig, L, R):
       
   361                     self.assertEqual(f(2, g(s)), f(2,s))
       
   362                 self.assertEqual(f(2, S(s)), [])
       
   363                 self.assertRaises(TypeError, f, 2, X(s))
       
   364                 self.assertRaises(TypeError, f, 2, N(s))
       
   365                 self.assertRaises(ZeroDivisionError, f, 2, E(s))
       
   366 
       
   367 
       
   368 #==============================================================================
       
   369 
       
   370 
       
   371 def test_main(verbose=None):
       
   372     from types import BuiltinFunctionType
       
   373 
       
   374     test_classes = [TestHeapPython, TestHeapC, TestErrorHandling]
       
   375     test_support.run_unittest(*test_classes)
       
   376 
       
   377     # verify reference counting
       
   378     if verbose and hasattr(sys, "gettotalrefcount"):
       
   379         import gc
       
   380         counts = [None] * 5
       
   381         for i in xrange(len(counts)):
       
   382             test_support.run_unittest(*test_classes)
       
   383             gc.collect()
       
   384             counts[i] = sys.gettotalrefcount()
       
   385         print counts
       
   386 
       
   387 if __name__ == "__main__":
       
   388     test_main(verbose=True)