(root)/
Python-3.12.0/
Lib/
test/
test_richcmp.py
       1  # Tests for rich comparisons
       2  
       3  import unittest
       4  from test import support
       5  
       6  import operator
       7  
       8  class ESC[4;38;5;81mNumber:
       9  
      10      def __init__(self, x):
      11          self.x = x
      12  
      13      def __lt__(self, other):
      14          return self.x < other
      15  
      16      def __le__(self, other):
      17          return self.x <= other
      18  
      19      def __eq__(self, other):
      20          return self.x == other
      21  
      22      def __ne__(self, other):
      23          return self.x != other
      24  
      25      def __gt__(self, other):
      26          return self.x > other
      27  
      28      def __ge__(self, other):
      29          return self.x >= other
      30  
      31      def __cmp__(self, other):
      32          raise support.TestFailed("Number.__cmp__() should not be called")
      33  
      34      def __repr__(self):
      35          return "Number(%r)" % (self.x, )
      36  
      37  class ESC[4;38;5;81mVector:
      38  
      39      def __init__(self, data):
      40          self.data = data
      41  
      42      def __len__(self):
      43          return len(self.data)
      44  
      45      def __getitem__(self, i):
      46          return self.data[i]
      47  
      48      def __setitem__(self, i, v):
      49          self.data[i] = v
      50  
      51      __hash__ = None # Vectors cannot be hashed
      52  
      53      def __bool__(self):
      54          raise TypeError("Vectors cannot be used in Boolean contexts")
      55  
      56      def __cmp__(self, other):
      57          raise support.TestFailed("Vector.__cmp__() should not be called")
      58  
      59      def __repr__(self):
      60          return "Vector(%r)" % (self.data, )
      61  
      62      def __lt__(self, other):
      63          return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
      64  
      65      def __le__(self, other):
      66          return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
      67  
      68      def __eq__(self, other):
      69          return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
      70  
      71      def __ne__(self, other):
      72          return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
      73  
      74      def __gt__(self, other):
      75          return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
      76  
      77      def __ge__(self, other):
      78          return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
      79  
      80      def __cast(self, other):
      81          if isinstance(other, Vector):
      82              other = other.data
      83          if len(self.data) != len(other):
      84              raise ValueError("Cannot compare vectors of different length")
      85          return other
      86  
      87  opmap = {
      88      "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
      89      "le": (lambda a,b: a<=b, operator.le, operator.__le__),
      90      "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
      91      "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
      92      "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
      93      "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
      94  }
      95  
      96  class ESC[4;38;5;81mVectorTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      97  
      98      def checkfail(self, error, opname, *args):
      99          for op in opmap[opname]:
     100              self.assertRaises(error, op, *args)
     101  
     102      def checkequal(self, opname, a, b, expres):
     103          for op in opmap[opname]:
     104              realres = op(a, b)
     105              # can't use assertEqual(realres, expres) here
     106              self.assertEqual(len(realres), len(expres))
     107              for i in range(len(realres)):
     108                  # results are bool, so we can use "is" here
     109                  self.assertTrue(realres[i] is expres[i])
     110  
     111      def test_mixed(self):
     112          # check that comparisons involving Vector objects
     113          # which return rich results (i.e. Vectors with itemwise
     114          # comparison results) work
     115          a = Vector(range(2))
     116          b = Vector(range(3))
     117          # all comparisons should fail for different length
     118          for opname in opmap:
     119              self.checkfail(ValueError, opname, a, b)
     120  
     121          a = list(range(5))
     122          b = 5 * [2]
     123          # try mixed arguments (but not (a, b) as that won't return a bool vector)
     124          args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
     125          for (a, b) in args:
     126              self.checkequal("lt", a, b, [True,  True,  False, False, False])
     127              self.checkequal("le", a, b, [True,  True,  True,  False, False])
     128              self.checkequal("eq", a, b, [False, False, True,  False, False])
     129              self.checkequal("ne", a, b, [True,  True,  False, True,  True ])
     130              self.checkequal("gt", a, b, [False, False, False, True,  True ])
     131              self.checkequal("ge", a, b, [False, False, True,  True,  True ])
     132  
     133              for ops in opmap.values():
     134                  for op in ops:
     135                      # calls __bool__, which should fail
     136                      self.assertRaises(TypeError, bool, op(a, b))
     137  
     138  class ESC[4;38;5;81mNumberTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     139  
     140      def test_basic(self):
     141          # Check that comparisons involving Number objects
     142          # give the same results give as comparing the
     143          # corresponding ints
     144          for a in range(3):
     145              for b in range(3):
     146                  for typea in (int, Number):
     147                      for typeb in (int, Number):
     148                          if typea==typeb==int:
     149                              continue # the combination int, int is useless
     150                          ta = typea(a)
     151                          tb = typeb(b)
     152                          for ops in opmap.values():
     153                              for op in ops:
     154                                  realoutcome = op(a, b)
     155                                  testoutcome = op(ta, tb)
     156                                  self.assertEqual(realoutcome, testoutcome)
     157  
     158      def checkvalue(self, opname, a, b, expres):
     159          for typea in (int, Number):
     160              for typeb in (int, Number):
     161                  ta = typea(a)
     162                  tb = typeb(b)
     163                  for op in opmap[opname]:
     164                      realres = op(ta, tb)
     165                      realres = getattr(realres, "x", realres)
     166                      self.assertTrue(realres is expres)
     167  
     168      def test_values(self):
     169          # check all operators and all comparison results
     170          self.checkvalue("lt", 0, 0, False)
     171          self.checkvalue("le", 0, 0, True )
     172          self.checkvalue("eq", 0, 0, True )
     173          self.checkvalue("ne", 0, 0, False)
     174          self.checkvalue("gt", 0, 0, False)
     175          self.checkvalue("ge", 0, 0, True )
     176  
     177          self.checkvalue("lt", 0, 1, True )
     178          self.checkvalue("le", 0, 1, True )
     179          self.checkvalue("eq", 0, 1, False)
     180          self.checkvalue("ne", 0, 1, True )
     181          self.checkvalue("gt", 0, 1, False)
     182          self.checkvalue("ge", 0, 1, False)
     183  
     184          self.checkvalue("lt", 1, 0, False)
     185          self.checkvalue("le", 1, 0, False)
     186          self.checkvalue("eq", 1, 0, False)
     187          self.checkvalue("ne", 1, 0, True )
     188          self.checkvalue("gt", 1, 0, True )
     189          self.checkvalue("ge", 1, 0, True )
     190  
     191  class ESC[4;38;5;81mMiscTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     192  
     193      def test_misbehavin(self):
     194          class ESC[4;38;5;81mMisb:
     195              def __lt__(self_, other): return 0
     196              def __gt__(self_, other): return 0
     197              def __eq__(self_, other): return 0
     198              def __le__(self_, other): self.fail("This shouldn't happen")
     199              def __ge__(self_, other): self.fail("This shouldn't happen")
     200              def __ne__(self_, other): self.fail("This shouldn't happen")
     201          a = Misb()
     202          b = Misb()
     203          self.assertEqual(a<b, 0)
     204          self.assertEqual(a==b, 0)
     205          self.assertEqual(a>b, 0)
     206  
     207      def test_not(self):
     208          # Check that exceptions in __bool__ are properly
     209          # propagated by the not operator
     210          import operator
     211          class ESC[4;38;5;81mExc(ESC[4;38;5;149mException):
     212              pass
     213          class ESC[4;38;5;81mBad:
     214              def __bool__(self):
     215                  raise Exc
     216  
     217          def do(bad):
     218              not bad
     219  
     220          for func in (do, operator.not_):
     221              self.assertRaises(Exc, func, Bad())
     222  
     223      @support.no_tracing
     224      def test_recursion(self):
     225          # Check that comparison for recursive objects fails gracefully
     226          from collections import UserList
     227          a = UserList()
     228          b = UserList()
     229          a.append(b)
     230          b.append(a)
     231          self.assertRaises(RecursionError, operator.eq, a, b)
     232          self.assertRaises(RecursionError, operator.ne, a, b)
     233          self.assertRaises(RecursionError, operator.lt, a, b)
     234          self.assertRaises(RecursionError, operator.le, a, b)
     235          self.assertRaises(RecursionError, operator.gt, a, b)
     236          self.assertRaises(RecursionError, operator.ge, a, b)
     237  
     238          b.append(17)
     239          # Even recursive lists of different lengths are different,
     240          # but they cannot be ordered
     241          self.assertTrue(not (a == b))
     242          self.assertTrue(a != b)
     243          self.assertRaises(RecursionError, operator.lt, a, b)
     244          self.assertRaises(RecursionError, operator.le, a, b)
     245          self.assertRaises(RecursionError, operator.gt, a, b)
     246          self.assertRaises(RecursionError, operator.ge, a, b)
     247          a.append(17)
     248          self.assertRaises(RecursionError, operator.eq, a, b)
     249          self.assertRaises(RecursionError, operator.ne, a, b)
     250          a.insert(0, 11)
     251          b.insert(0, 12)
     252          self.assertTrue(not (a == b))
     253          self.assertTrue(a != b)
     254          self.assertTrue(a < b)
     255  
     256      def test_exception_message(self):
     257          class ESC[4;38;5;81mSpam:
     258              pass
     259  
     260          tests = [
     261              (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
     262              (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
     263              (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
     264              (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
     265              (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
     266              (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
     267              (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
     268              (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
     269              (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
     270              (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
     271              (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
     272              (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
     273              (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
     274          ]
     275          for i, test in enumerate(tests):
     276              with self.subTest(test=i):
     277                  with self.assertRaisesRegex(TypeError, test[1]):
     278                      test[0]()
     279  
     280  
     281  class ESC[4;38;5;81mDictTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     282  
     283      def test_dicts(self):
     284          # Verify that __eq__ and __ne__ work for dicts even if the keys and
     285          # values don't support anything other than __eq__ and __ne__ (and
     286          # __hash__).  Complex numbers are a fine example of that.
     287          import random
     288          imag1a = {}
     289          for i in range(50):
     290              imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
     291          items = list(imag1a.items())
     292          random.shuffle(items)
     293          imag1b = {}
     294          for k, v in items:
     295              imag1b[k] = v
     296          imag2 = imag1b.copy()
     297          imag2[k] = v + 1.0
     298          self.assertEqual(imag1a, imag1a)
     299          self.assertEqual(imag1a, imag1b)
     300          self.assertEqual(imag2, imag2)
     301          self.assertTrue(imag1a != imag2)
     302          for opname in ("lt", "le", "gt", "ge"):
     303              for op in opmap[opname]:
     304                  self.assertRaises(TypeError, op, imag1a, imag2)
     305  
     306  class ESC[4;38;5;81mListTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     307  
     308      def test_coverage(self):
     309          # exercise all comparisons for lists
     310          x = [42]
     311          self.assertIs(x<x, False)
     312          self.assertIs(x<=x, True)
     313          self.assertIs(x==x, True)
     314          self.assertIs(x!=x, False)
     315          self.assertIs(x>x, False)
     316          self.assertIs(x>=x, True)
     317          y = [42, 42]
     318          self.assertIs(x<y, True)
     319          self.assertIs(x<=y, True)
     320          self.assertIs(x==y, False)
     321          self.assertIs(x!=y, True)
     322          self.assertIs(x>y, False)
     323          self.assertIs(x>=y, False)
     324  
     325      def test_badentry(self):
     326          # make sure that exceptions for item comparison are properly
     327          # propagated in list comparisons
     328          class ESC[4;38;5;81mExc(ESC[4;38;5;149mException):
     329              pass
     330          class ESC[4;38;5;81mBad:
     331              def __eq__(self, other):
     332                  raise Exc
     333  
     334          x = [Bad()]
     335          y = [Bad()]
     336  
     337          for op in opmap["eq"]:
     338              self.assertRaises(Exc, op, x, y)
     339  
     340      def test_goodentry(self):
     341          # This test exercises the final call to PyObject_RichCompare()
     342          # in Objects/listobject.c::list_richcompare()
     343          class ESC[4;38;5;81mGood:
     344              def __lt__(self, other):
     345                  return True
     346  
     347          x = [Good()]
     348          y = [Good()]
     349  
     350          for op in opmap["lt"]:
     351              self.assertIs(op(x, y), True)
     352  
     353  
     354  if __name__ == "__main__":
     355      unittest.main()