python-2.5.2/win32/Lib/test/test_richcmp.py
changeset 0 ae805ac0140d
equal deleted inserted replaced
-1:000000000000 0:ae805ac0140d
       
     1 # Tests for rich comparisons
       
     2 
       
     3 import unittest
       
     4 from test import test_support
       
     5 
       
     6 import operator
       
     7 
       
     8 class Number:
       
     9 
       
    10     def __init__(self, x):
       
    11         self.x = x
       
    12 
       
    13     def __lt__(self, other):
       
    14         return self.x < other
       
    15 
       
    16     def __le__(self, other):
       
    17         return self.x <= other
       
    18 
       
    19     def __eq__(self, other):
       
    20         return self.x == other
       
    21 
       
    22     def __ne__(self, other):
       
    23         return self.x != other
       
    24 
       
    25     def __gt__(self, other):
       
    26         return self.x > other
       
    27 
       
    28     def __ge__(self, other):
       
    29         return self.x >= other
       
    30 
       
    31     def __cmp__(self, other):
       
    32         raise test_support.TestFailed, "Number.__cmp__() should not be called"
       
    33 
       
    34     def __repr__(self):
       
    35         return "Number(%r)" % (self.x, )
       
    36 
       
    37 class Vector:
       
    38 
       
    39     def __init__(self, data):
       
    40         self.data = data
       
    41 
       
    42     def __len__(self):
       
    43         return len(self.data)
       
    44 
       
    45     def __getitem__(self, i):
       
    46         return self.data[i]
       
    47 
       
    48     def __setitem__(self, i, v):
       
    49         self.data[i] = v
       
    50 
       
    51     def __hash__(self):
       
    52         raise TypeError, "Vectors cannot be hashed"
       
    53 
       
    54     def __nonzero__(self):
       
    55         raise TypeError, "Vectors cannot be used in Boolean contexts"
       
    56 
       
    57     def __cmp__(self, other):
       
    58         raise test_support.TestFailed, "Vector.__cmp__() should not be called"
       
    59 
       
    60     def __repr__(self):
       
    61         return "Vector(%r)" % (self.data, )
       
    62 
       
    63     def __lt__(self, other):
       
    64         return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
       
    65 
       
    66     def __le__(self, other):
       
    67         return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
       
    68 
       
    69     def __eq__(self, other):
       
    70         return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
       
    71 
       
    72     def __ne__(self, other):
       
    73         return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
       
    74 
       
    75     def __gt__(self, other):
       
    76         return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
       
    77 
       
    78     def __ge__(self, other):
       
    79         return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
       
    80 
       
    81     def __cast(self, other):
       
    82         if isinstance(other, Vector):
       
    83             other = other.data
       
    84         if len(self.data) != len(other):
       
    85             raise ValueError, "Cannot compare vectors of different length"
       
    86         return other
       
    87 
       
    88 opmap = {
       
    89     "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
       
    90     "le": (lambda a,b: a<=b, operator.le, operator.__le__),
       
    91     "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
       
    92     "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
       
    93     "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
       
    94     "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
       
    95 }
       
    96 
       
    97 class VectorTest(unittest.TestCase):
       
    98 
       
    99     def checkfail(self, error, opname, *args):
       
   100         for op in opmap[opname]:
       
   101             self.assertRaises(error, op, *args)
       
   102 
       
   103     def checkequal(self, opname, a, b, expres):
       
   104         for op in opmap[opname]:
       
   105             realres = op(a, b)
       
   106             # can't use assertEqual(realres, expres) here
       
   107             self.assertEqual(len(realres), len(expres))
       
   108             for i in xrange(len(realres)):
       
   109                 # results are bool, so we can use "is" here
       
   110                 self.assert_(realres[i] is expres[i])
       
   111 
       
   112     def test_mixed(self):
       
   113         # check that comparisons involving Vector objects
       
   114         # which return rich results (i.e. Vectors with itemwise
       
   115         # comparison results) work
       
   116         a = Vector(range(2))
       
   117         b = Vector(range(3))
       
   118         # all comparisons should fail for different length
       
   119         for opname in opmap:
       
   120             self.checkfail(ValueError, opname, a, b)
       
   121 
       
   122         a = range(5)
       
   123         b = 5 * [2]
       
   124         # try mixed arguments (but not (a, b) as that won't return a bool vector)
       
   125         args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
       
   126         for (a, b) in args:
       
   127             self.checkequal("lt", a, b, [True,  True,  False, False, False])
       
   128             self.checkequal("le", a, b, [True,  True,  True,  False, False])
       
   129             self.checkequal("eq", a, b, [False, False, True,  False, False])
       
   130             self.checkequal("ne", a, b, [True,  True,  False, True,  True ])
       
   131             self.checkequal("gt", a, b, [False, False, False, True,  True ])
       
   132             self.checkequal("ge", a, b, [False, False, True,  True,  True ])
       
   133 
       
   134             for ops in opmap.itervalues():
       
   135                 for op in ops:
       
   136                     # calls __nonzero__, which should fail
       
   137                     self.assertRaises(TypeError, bool, op(a, b))
       
   138 
       
   139 class NumberTest(unittest.TestCase):
       
   140 
       
   141     def test_basic(self):
       
   142         # Check that comparisons involving Number objects
       
   143         # give the same results give as comparing the
       
   144         # corresponding ints
       
   145         for a in xrange(3):
       
   146             for b in xrange(3):
       
   147                 for typea in (int, Number):
       
   148                     for typeb in (int, Number):
       
   149                         if typea==typeb==int:
       
   150                             continue # the combination int, int is useless
       
   151                         ta = typea(a)
       
   152                         tb = typeb(b)
       
   153                         for ops in opmap.itervalues():
       
   154                             for op in ops:
       
   155                                 realoutcome = op(a, b)
       
   156                                 testoutcome = op(ta, tb)
       
   157                                 self.assertEqual(realoutcome, testoutcome)
       
   158 
       
   159     def checkvalue(self, opname, a, b, expres):
       
   160         for typea in (int, Number):
       
   161             for typeb in (int, Number):
       
   162                 ta = typea(a)
       
   163                 tb = typeb(b)
       
   164                 for op in opmap[opname]:
       
   165                     realres = op(ta, tb)
       
   166                     realres = getattr(realres, "x", realres)
       
   167                     self.assert_(realres is expres)
       
   168 
       
   169     def test_values(self):
       
   170         # check all operators and all comparison results
       
   171         self.checkvalue("lt", 0, 0, False)
       
   172         self.checkvalue("le", 0, 0, True )
       
   173         self.checkvalue("eq", 0, 0, True )
       
   174         self.checkvalue("ne", 0, 0, False)
       
   175         self.checkvalue("gt", 0, 0, False)
       
   176         self.checkvalue("ge", 0, 0, True )
       
   177 
       
   178         self.checkvalue("lt", 0, 1, True )
       
   179         self.checkvalue("le", 0, 1, True )
       
   180         self.checkvalue("eq", 0, 1, False)
       
   181         self.checkvalue("ne", 0, 1, True )
       
   182         self.checkvalue("gt", 0, 1, False)
       
   183         self.checkvalue("ge", 0, 1, False)
       
   184 
       
   185         self.checkvalue("lt", 1, 0, False)
       
   186         self.checkvalue("le", 1, 0, False)
       
   187         self.checkvalue("eq", 1, 0, False)
       
   188         self.checkvalue("ne", 1, 0, True )
       
   189         self.checkvalue("gt", 1, 0, True )
       
   190         self.checkvalue("ge", 1, 0, True )
       
   191 
       
   192 class MiscTest(unittest.TestCase):
       
   193 
       
   194     def test_misbehavin(self):
       
   195         class Misb:
       
   196             def __lt__(self, other): return 0
       
   197             def __gt__(self, other): return 0
       
   198             def __eq__(self, other): return 0
       
   199             def __le__(self, other): raise TestFailed, "This shouldn't happen"
       
   200             def __ge__(self, other): raise TestFailed, "This shouldn't happen"
       
   201             def __ne__(self, other): raise TestFailed, "This shouldn't happen"
       
   202             def __cmp__(self, other): raise RuntimeError, "expected"
       
   203         a = Misb()
       
   204         b = Misb()
       
   205         self.assertEqual(a<b, 0)
       
   206         self.assertEqual(a==b, 0)
       
   207         self.assertEqual(a>b, 0)
       
   208         self.assertRaises(RuntimeError, cmp, a, b)
       
   209 
       
   210     def test_not(self):
       
   211         # Check that exceptions in __nonzero__ are properly
       
   212         # propagated by the not operator
       
   213         import operator
       
   214         class Exc(Exception):
       
   215             pass
       
   216         class Bad:
       
   217             def __nonzero__(self):
       
   218                 raise Exc
       
   219 
       
   220         def do(bad):
       
   221             not bad
       
   222 
       
   223         for func in (do, operator.not_):
       
   224             self.assertRaises(Exc, func, Bad())
       
   225 
       
   226     def test_recursion(self):
       
   227         # Check that comparison for recursive objects fails gracefully
       
   228         from UserList import UserList
       
   229         a = UserList()
       
   230         b = UserList()
       
   231         a.append(b)
       
   232         b.append(a)
       
   233         self.assertRaises(RuntimeError, operator.eq, a, b)
       
   234         self.assertRaises(RuntimeError, operator.ne, a, b)
       
   235         self.assertRaises(RuntimeError, operator.lt, a, b)
       
   236         self.assertRaises(RuntimeError, operator.le, a, b)
       
   237         self.assertRaises(RuntimeError, operator.gt, a, b)
       
   238         self.assertRaises(RuntimeError, operator.ge, a, b)
       
   239 
       
   240         b.append(17)
       
   241         # Even recursive lists of different lengths are different,
       
   242         # but they cannot be ordered
       
   243         self.assert_(not (a == b))
       
   244         self.assert_(a != b)
       
   245         self.assertRaises(RuntimeError, operator.lt, a, b)
       
   246         self.assertRaises(RuntimeError, operator.le, a, b)
       
   247         self.assertRaises(RuntimeError, operator.gt, a, b)
       
   248         self.assertRaises(RuntimeError, operator.ge, a, b)
       
   249         a.append(17)
       
   250         self.assertRaises(RuntimeError, operator.eq, a, b)
       
   251         self.assertRaises(RuntimeError, operator.ne, a, b)
       
   252         a.insert(0, 11)
       
   253         b.insert(0, 12)
       
   254         self.assert_(not (a == b))
       
   255         self.assert_(a != b)
       
   256         self.assert_(a < b)
       
   257 
       
   258 class DictTest(unittest.TestCase):
       
   259 
       
   260     def test_dicts(self):
       
   261         # Verify that __eq__ and __ne__ work for dicts even if the keys and
       
   262         # values don't support anything other than __eq__ and __ne__ (and
       
   263         # __hash__).  Complex numbers are a fine example of that.
       
   264         import random
       
   265         imag1a = {}
       
   266         for i in range(50):
       
   267             imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
       
   268         items = imag1a.items()
       
   269         random.shuffle(items)
       
   270         imag1b = {}
       
   271         for k, v in items:
       
   272             imag1b[k] = v
       
   273         imag2 = imag1b.copy()
       
   274         imag2[k] = v + 1.0
       
   275         self.assert_(imag1a == imag1a)
       
   276         self.assert_(imag1a == imag1b)
       
   277         self.assert_(imag2 == imag2)
       
   278         self.assert_(imag1a != imag2)
       
   279         for opname in ("lt", "le", "gt", "ge"):
       
   280             for op in opmap[opname]:
       
   281                 self.assertRaises(TypeError, op, imag1a, imag2)
       
   282 
       
   283 class ListTest(unittest.TestCase):
       
   284 
       
   285     def assertIs(self, a, b):
       
   286         self.assert_(a is b)
       
   287 
       
   288     def test_coverage(self):
       
   289         # exercise all comparisons for lists
       
   290         x = [42]
       
   291         self.assertIs(x<x, False)
       
   292         self.assertIs(x<=x, True)
       
   293         self.assertIs(x==x, True)
       
   294         self.assertIs(x!=x, False)
       
   295         self.assertIs(x>x, False)
       
   296         self.assertIs(x>=x, True)
       
   297         y = [42, 42]
       
   298         self.assertIs(x<y, True)
       
   299         self.assertIs(x<=y, True)
       
   300         self.assertIs(x==y, False)
       
   301         self.assertIs(x!=y, True)
       
   302         self.assertIs(x>y, False)
       
   303         self.assertIs(x>=y, False)
       
   304 
       
   305     def test_badentry(self):
       
   306         # make sure that exceptions for item comparison are properly
       
   307         # propagated in list comparisons
       
   308         class Exc(Exception):
       
   309             pass
       
   310         class Bad:
       
   311             def __eq__(self, other):
       
   312                 raise Exc
       
   313 
       
   314         x = [Bad()]
       
   315         y = [Bad()]
       
   316 
       
   317         for op in opmap["eq"]:
       
   318             self.assertRaises(Exc, op, x, y)
       
   319 
       
   320     def test_goodentry(self):
       
   321         # This test exercises the final call to PyObject_RichCompare()
       
   322         # in Objects/listobject.c::list_richcompare()
       
   323         class Good:
       
   324             def __lt__(self, other):
       
   325                 return True
       
   326 
       
   327         x = [Good()]
       
   328         y = [Good()]
       
   329 
       
   330         for op in opmap["lt"]:
       
   331             self.assertIs(op(x, y), True)
       
   332 
       
   333 def test_main():
       
   334     test_support.run_unittest(VectorTest, NumberTest, MiscTest, DictTest, ListTest)
       
   335 
       
   336 if __name__ == "__main__":
       
   337     test_main()