python-2.5.2/win32/Lib/test/test_heapq.py
changeset 0 ae805ac0140d
equal deleted inserted replaced
-1:000000000000 0:ae805ac0140d
       
     1 """Unittests for heapq."""
       
     2 
       
     3 from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest
       
     4 import random
       
     5 import unittest
       
     6 from test import test_support
       
     7 import sys
       
     8 
       
     9 
       
    10 def heapiter(heap):
       
    11     # An iterator returning a heap's elements, smallest-first.
       
    12     try:
       
    13         while 1:
       
    14             yield heappop(heap)
       
    15     except IndexError:
       
    16         pass
       
    17 
       
    18 class TestHeap(unittest.TestCase):
       
    19 
       
    20     def test_push_pop(self):
       
    21         # 1) Push 256 random numbers and pop them off, verifying all's OK.
       
    22         heap = []
       
    23         data = []
       
    24         self.check_invariant(heap)
       
    25         for i in range(256):
       
    26             item = random.random()
       
    27             data.append(item)
       
    28             heappush(heap, item)
       
    29             self.check_invariant(heap)
       
    30         results = []
       
    31         while heap:
       
    32             item = heappop(heap)
       
    33             self.check_invariant(heap)
       
    34             results.append(item)
       
    35         data_sorted = data[:]
       
    36         data_sorted.sort()
       
    37         self.assertEqual(data_sorted, results)
       
    38         # 2) Check that the invariant holds for a sorted array
       
    39         self.check_invariant(results)
       
    40 
       
    41         self.assertRaises(TypeError, heappush, [])
       
    42         try:
       
    43             self.assertRaises(TypeError, heappush, None, None)
       
    44             self.assertRaises(TypeError, heappop, None)
       
    45         except AttributeError:
       
    46             pass
       
    47 
       
    48     def check_invariant(self, heap):
       
    49         # Check the heap invariant.
       
    50         for pos, item in enumerate(heap):
       
    51             if pos: # pos 0 has no parent
       
    52                 parentpos = (pos-1) >> 1
       
    53                 self.assert_(heap[parentpos] <= item)
       
    54 
       
    55     def test_heapify(self):
       
    56         for size in range(30):
       
    57             heap = [random.random() for dummy in range(size)]
       
    58             heapify(heap)
       
    59             self.check_invariant(heap)
       
    60 
       
    61         self.assertRaises(TypeError, heapify, None)
       
    62 
       
    63     def test_naive_nbest(self):
       
    64         data = [random.randrange(2000) for i in range(1000)]
       
    65         heap = []
       
    66         for item in data:
       
    67             heappush(heap, item)
       
    68             if len(heap) > 10:
       
    69                 heappop(heap)
       
    70         heap.sort()
       
    71         self.assertEqual(heap, sorted(data)[-10:])
       
    72 
       
    73     def test_nbest(self):
       
    74         # Less-naive "N-best" algorithm, much faster (if len(data) is big
       
    75         # enough <wink>) than sorting all of data.  However, if we had a max
       
    76         # heap instead of a min heap, it could go faster still via
       
    77         # heapify'ing all of data (linear time), then doing 10 heappops
       
    78         # (10 log-time steps).
       
    79         data = [random.randrange(2000) for i in range(1000)]
       
    80         heap = data[:10]
       
    81         heapify(heap)
       
    82         for item in data[10:]:
       
    83             if item > heap[0]:  # this gets rarer the longer we run
       
    84                 heapreplace(heap, item)
       
    85         self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])
       
    86 
       
    87         self.assertRaises(TypeError, heapreplace, None)
       
    88         self.assertRaises(TypeError, heapreplace, None, None)
       
    89         self.assertRaises(IndexError, heapreplace, [], None)
       
    90 
       
    91     def test_heapsort(self):
       
    92         # Exercise everything with repeated heapsort checks
       
    93         for trial in xrange(100):
       
    94             size = random.randrange(50)
       
    95             data = [random.randrange(25) for i in range(size)]
       
    96             if trial & 1:     # Half of the time, use heapify
       
    97                 heap = data[:]
       
    98                 heapify(heap)
       
    99             else:             # The rest of the time, use heappush
       
   100                 heap = []
       
   101                 for item in data:
       
   102                     heappush(heap, item)
       
   103             heap_sorted = [heappop(heap) for i in range(size)]
       
   104             self.assertEqual(heap_sorted, sorted(data))
       
   105 
       
   106     def test_nsmallest(self):
       
   107         data = [(random.randrange(2000), i) for i in range(1000)]
       
   108         for f in (None, lambda x:  x[0] * 547 % 2000):
       
   109             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
       
   110                 self.assertEqual(nsmallest(n, data), sorted(data)[:n])
       
   111                 self.assertEqual(nsmallest(n, data, key=f),
       
   112                                  sorted(data, key=f)[:n])
       
   113 
       
   114     def test_nlargest(self):
       
   115         data = [(random.randrange(2000), i) for i in range(1000)]
       
   116         for f in (None, lambda x:  x[0] * 547 % 2000):
       
   117             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
       
   118                 self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
       
   119                 self.assertEqual(nlargest(n, data, key=f),
       
   120                                  sorted(data, key=f, reverse=True)[:n])
       
   121 
       
   122 
       
   123 #==============================================================================
       
   124 
       
   125 class LenOnly:
       
   126     "Dummy sequence class defining __len__ but not __getitem__."
       
   127     def __len__(self):
       
   128         return 10
       
   129 
       
   130 class GetOnly:
       
   131     "Dummy sequence class defining __getitem__ but not __len__."
       
   132     def __getitem__(self, ndx):
       
   133         return 10
       
   134 
       
   135 class CmpErr:
       
   136     "Dummy element that always raises an error during comparison"
       
   137     def __cmp__(self, other):
       
   138         raise ZeroDivisionError
       
   139 
       
   140 def R(seqn):
       
   141     'Regular generator'
       
   142     for i in seqn:
       
   143         yield i
       
   144 
       
   145 class G:
       
   146     'Sequence using __getitem__'
       
   147     def __init__(self, seqn):
       
   148         self.seqn = seqn
       
   149     def __getitem__(self, i):
       
   150         return self.seqn[i]
       
   151 
       
   152 class I:
       
   153     'Sequence using iterator protocol'
       
   154     def __init__(self, seqn):
       
   155         self.seqn = seqn
       
   156         self.i = 0
       
   157     def __iter__(self):
       
   158         return self
       
   159     def next(self):
       
   160         if self.i >= len(self.seqn): raise StopIteration
       
   161         v = self.seqn[self.i]
       
   162         self.i += 1
       
   163         return v
       
   164 
       
   165 class Ig:
       
   166     'Sequence using iterator protocol defined with a generator'
       
   167     def __init__(self, seqn):
       
   168         self.seqn = seqn
       
   169         self.i = 0
       
   170     def __iter__(self):
       
   171         for val in self.seqn:
       
   172             yield val
       
   173 
       
   174 class X:
       
   175     'Missing __getitem__ and __iter__'
       
   176     def __init__(self, seqn):
       
   177         self.seqn = seqn
       
   178         self.i = 0
       
   179     def next(self):
       
   180         if self.i >= len(self.seqn): raise StopIteration
       
   181         v = self.seqn[self.i]
       
   182         self.i += 1
       
   183         return v
       
   184 
       
   185 class N:
       
   186     'Iterator missing next()'
       
   187     def __init__(self, seqn):
       
   188         self.seqn = seqn
       
   189         self.i = 0
       
   190     def __iter__(self):
       
   191         return self
       
   192 
       
   193 class E:
       
   194     'Test propagation of exceptions'
       
   195     def __init__(self, seqn):
       
   196         self.seqn = seqn
       
   197         self.i = 0
       
   198     def __iter__(self):
       
   199         return self
       
   200     def next(self):
       
   201         3 // 0
       
   202 
       
   203 class S:
       
   204     'Test immediate stop'
       
   205     def __init__(self, seqn):
       
   206         pass
       
   207     def __iter__(self):
       
   208         return self
       
   209     def next(self):
       
   210         raise StopIteration
       
   211 
       
   212 from itertools import chain, imap
       
   213 def L(seqn):
       
   214     'Test multiple tiers of iterators'
       
   215     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
       
   216 
       
   217 class TestErrorHandling(unittest.TestCase):
       
   218 
       
   219     def test_non_sequence(self):
       
   220         for f in (heapify, heappop):
       
   221             self.assertRaises(TypeError, f, 10)
       
   222         for f in (heappush, heapreplace, nlargest, nsmallest):
       
   223             self.assertRaises(TypeError, f, 10, 10)
       
   224 
       
   225     def test_len_only(self):
       
   226         for f in (heapify, heappop):
       
   227             self.assertRaises(TypeError, f, LenOnly())
       
   228         for f in (heappush, heapreplace):
       
   229             self.assertRaises(TypeError, f, LenOnly(), 10)
       
   230         for f in (nlargest, nsmallest):
       
   231             self.assertRaises(TypeError, f, 2, LenOnly())
       
   232 
       
   233     def test_get_only(self):
       
   234         for f in (heapify, heappop):
       
   235             self.assertRaises(TypeError, f, GetOnly())
       
   236         for f in (heappush, heapreplace):
       
   237             self.assertRaises(TypeError, f, GetOnly(), 10)
       
   238         for f in (nlargest, nsmallest):
       
   239             self.assertRaises(TypeError, f, 2, GetOnly())
       
   240 
       
   241     def test_get_only(self):
       
   242         seq = [CmpErr(), CmpErr(), CmpErr()]
       
   243         for f in (heapify, heappop):
       
   244             self.assertRaises(ZeroDivisionError, f, seq)
       
   245         for f in (heappush, heapreplace):
       
   246             self.assertRaises(ZeroDivisionError, f, seq, 10)
       
   247         for f in (nlargest, nsmallest):
       
   248             self.assertRaises(ZeroDivisionError, f, 2, seq)
       
   249 
       
   250     def test_arg_parsing(self):
       
   251         for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest):
       
   252             self.assertRaises(TypeError, f, 10)
       
   253 
       
   254     def test_iterable_args(self):
       
   255         for f in  (nlargest, nsmallest):
       
   256             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
       
   257                 for g in (G, I, Ig, L, R):
       
   258                     self.assertEqual(f(2, g(s)), f(2,s))
       
   259                 self.assertEqual(f(2, S(s)), [])
       
   260                 self.assertRaises(TypeError, f, 2, X(s))
       
   261                 self.assertRaises(TypeError, f, 2, N(s))
       
   262                 self.assertRaises(ZeroDivisionError, f, 2, E(s))
       
   263 
       
   264 #==============================================================================
       
   265 
       
   266 
       
def test_main(verbose=None):
    """Run the heapq tests through test_support.run_unittest.

    TestHeap always runs.  TestErrorHandling is added only when heapify is
    a builtin function, i.e. a C implementation was imported; presumably its
    exact-exception expectations are specific to that implementation.

    If *verbose* is true and the interpreter provides sys.gettotalrefcount
    (normally only debug builds), the suite is run five more times.  The
    total reference count after each run is collected and printed; a
    steadily increasing sequence suggests a reference leak.
    """
    from types import BuiltinFunctionType

    test_classes = [TestHeap]
    if isinstance(heapify, BuiltinFunctionType):
        test_classes.append(TestErrorHandling)
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        # NOTE(review): preallocated, presumably so that recording a count
        # does not itself allocate and skew the following measurements.
        counts = [None] * 5
        for i in xrange(len(counts)):
            test_support.run_unittest(*test_classes)
            # Collect cyclic garbage first so it is not counted as live.
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        print counts
       
   284 
       
   285 if __name__ == "__main__":
       
   286     test_main(verbose=True)