# Python 3.11.7 source listing: lib/python3.11/test/test_richcmp.py
# Tests for rich comparisons
       2  
       3  import unittest
       4  from test import support
       5  
       6  import operator
       7  
       8  class ESC[4;38;5;81mNumber:
       9  
      10      def __init__(self, x):
      11          self.x = x
      12  
      13      def __lt__(self, other):
      14          return self.x < other
      15  
      16      def __le__(self, other):
      17          return self.x <= other
      18  
      19      def __eq__(self, other):
      20          return self.x == other
      21  
      22      def __ne__(self, other):
      23          return self.x != other
      24  
      25      def __gt__(self, other):
      26          return self.x > other
      27  
      28      def __ge__(self, other):
      29          return self.x >= other
      30  
      31      def __cmp__(self, other):
      32          raise support.TestFailed("Number.__cmp__() should not be called")
      33  
      34      def __repr__(self):
      35          return "Number(%r)" % (self.x, )
      36  
      37  class ESC[4;38;5;81mVector:
      38  
      39      def __init__(self, data):
      40          self.data = data
      41  
      42      def __len__(self):
      43          return len(self.data)
      44  
      45      def __getitem__(self, i):
      46          return self.data[i]
      47  
      48      def __setitem__(self, i, v):
      49          self.data[i] = v
      50  
      51      __hash__ = None # Vectors cannot be hashed
      52  
      53      def __bool__(self):
      54          raise TypeError("Vectors cannot be used in Boolean contexts")
      55  
      56      def __cmp__(self, other):
      57          raise support.TestFailed("Vector.__cmp__() should not be called")
      58  
      59      def __repr__(self):
      60          return "Vector(%r)" % (self.data, )
      61  
      62      def __lt__(self, other):
      63          return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
      64  
      65      def __le__(self, other):
      66          return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
      67  
      68      def __eq__(self, other):
      69          return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
      70  
      71      def __ne__(self, other):
      72          return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
      73  
      74      def __gt__(self, other):
      75          return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
      76  
      77      def __ge__(self, other):
      78          return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
      79  
      80      def __cast(self, other):
      81          if isinstance(other, Vector):
      82              other = other.data
      83          if len(self.data) != len(other):
      84              raise ValueError("Cannot compare vectors of different length")
      85          return other
      86  
      87  opmap = {
      88      "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
      89      "le": (lambda a,b: a<=b, operator.le, operator.__le__),
      90      "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
      91      "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
      92      "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
      93      "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
      94  }
      95  
      96  class ESC[4;38;5;81mVectorTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      97  
      98      def checkfail(self, error, opname, *args):
      99          for op in opmap[opname]:
     100              self.assertRaises(error, op, *args)
     101  
     102      def checkequal(self, opname, a, b, expres):
     103          for op in opmap[opname]:
     104              realres = op(a, b)
     105              # can't use assertEqual(realres, expres) here
     106              self.assertEqual(len(realres), len(expres))
     107              for i in range(len(realres)):
     108                  # results are bool, so we can use "is" here
     109                  self.assertTrue(realres[i] is expres[i])
     110  
     111      def test_mixed(self):
     112          # check that comparisons involving Vector objects
     113          # which return rich results (i.e. Vectors with itemwise
     114          # comparison results) work
     115          a = Vector(range(2))
     116          b = Vector(range(3))
     117          # all comparisons should fail for different length
     118          for opname in opmap:
     119              self.checkfail(ValueError, opname, a, b)
     120  
     121          a = list(range(5))
     122          b = 5 * [2]
     123          # try mixed arguments (but not (a, b) as that won't return a bool vector)
     124          args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
     125          for (a, b) in args:
     126              self.checkequal("lt", a, b, [True,  True,  False, False, False])
     127              self.checkequal("le", a, b, [True,  True,  True,  False, False])
     128              self.checkequal("eq", a, b, [False, False, True,  False, False])
     129              self.checkequal("ne", a, b, [True,  True,  False, True,  True ])
     130              self.checkequal("gt", a, b, [False, False, False, True,  True ])
     131              self.checkequal("ge", a, b, [False, False, True,  True,  True ])
     132  
     133              for ops in opmap.values():
     134                  for op in ops:
     135                      # calls __bool__, which should fail
     136                      self.assertRaises(TypeError, bool, op(a, b))
     137  
     138  class ESC[4;38;5;81mNumberTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     139  
     140      def test_basic(self):
     141          # Check that comparisons involving Number objects
     142          # give the same results give as comparing the
     143          # corresponding ints
     144          for a in range(3):
     145              for b in range(3):
     146                  for typea in (int, Number):
     147                      for typeb in (int, Number):
     148                          if typea==typeb==int:
     149                              continue # the combination int, int is useless
     150                          ta = typea(a)
     151                          tb = typeb(b)
     152                          for ops in opmap.values():
     153                              for op in ops:
     154                                  realoutcome = op(a, b)
     155                                  testoutcome = op(ta, tb)
     156                                  self.assertEqual(realoutcome, testoutcome)
     157  
     158      def checkvalue(self, opname, a, b, expres):
     159          for typea in (int, Number):
     160              for typeb in (int, Number):
     161                  ta = typea(a)
     162                  tb = typeb(b)
     163                  for op in opmap[opname]:
     164                      realres = op(ta, tb)
     165                      realres = getattr(realres, "x", realres)
     166                      self.assertTrue(realres is expres)
     167  
     168      def test_values(self):
     169          # check all operators and all comparison results
     170          self.checkvalue("lt", 0, 0, False)
     171          self.checkvalue("le", 0, 0, True )
     172          self.checkvalue("eq", 0, 0, True )
     173          self.checkvalue("ne", 0, 0, False)
     174          self.checkvalue("gt", 0, 0, False)
     175          self.checkvalue("ge", 0, 0, True )
     176  
     177          self.checkvalue("lt", 0, 1, True )
     178          self.checkvalue("le", 0, 1, True )
     179          self.checkvalue("eq", 0, 1, False)
     180          self.checkvalue("ne", 0, 1, True )
     181          self.checkvalue("gt", 0, 1, False)
     182          self.checkvalue("ge", 0, 1, False)
     183  
     184          self.checkvalue("lt", 1, 0, False)
     185          self.checkvalue("le", 1, 0, False)
     186          self.checkvalue("eq", 1, 0, False)
     187          self.checkvalue("ne", 1, 0, True )
     188          self.checkvalue("gt", 1, 0, True )
     189          self.checkvalue("ge", 1, 0, True )
     190  
     191  class ESC[4;38;5;81mMiscTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     192  
     193      def test_misbehavin(self):
     194          class ESC[4;38;5;81mMisb:
     195              def __lt__(self_, other): return 0
     196              def __gt__(self_, other): return 0
     197              def __eq__(self_, other): return 0
     198              def __le__(self_, other): self.fail("This shouldn't happen")
     199              def __ge__(self_, other): self.fail("This shouldn't happen")
     200              def __ne__(self_, other): self.fail("This shouldn't happen")
     201          a = Misb()
     202          b = Misb()
     203          self.assertEqual(a<b, 0)
     204          self.assertEqual(a==b, 0)
     205          self.assertEqual(a>b, 0)
     206  
     207      def test_not(self):
     208          # Check that exceptions in __bool__ are properly
     209          # propagated by the not operator
     210          import operator
     211          class ESC[4;38;5;81mExc(ESC[4;38;5;149mException):
     212              pass
     213          class ESC[4;38;5;81mBad:
     214              def __bool__(self):
     215                  raise Exc
     216  
     217          def do(bad):
     218              not bad
     219  
     220          for func in (do, operator.not_):
     221              self.assertRaises(Exc, func, Bad())
     222  
     223      @support.no_tracing
     224      @support.infinite_recursion(25)
     225      def test_recursion(self):
     226          # Check that comparison for recursive objects fails gracefully
     227          from collections import UserList
     228          a = UserList()
     229          b = UserList()
     230          a.append(b)
     231          b.append(a)
     232          self.assertRaises(RecursionError, operator.eq, a, b)
     233          self.assertRaises(RecursionError, operator.ne, a, b)
     234          self.assertRaises(RecursionError, operator.lt, a, b)
     235          self.assertRaises(RecursionError, operator.le, a, b)
     236          self.assertRaises(RecursionError, operator.gt, a, b)
     237          self.assertRaises(RecursionError, operator.ge, a, b)
     238  
     239          b.append(17)
     240          # Even recursive lists of different lengths are different,
     241          # but they cannot be ordered
     242          self.assertTrue(not (a == b))
     243          self.assertTrue(a != b)
     244          self.assertRaises(RecursionError, operator.lt, a, b)
     245          self.assertRaises(RecursionError, operator.le, a, b)
     246          self.assertRaises(RecursionError, operator.gt, a, b)
     247          self.assertRaises(RecursionError, operator.ge, a, b)
     248          a.append(17)
     249          self.assertRaises(RecursionError, operator.eq, a, b)
     250          self.assertRaises(RecursionError, operator.ne, a, b)
     251          a.insert(0, 11)
     252          b.insert(0, 12)
     253          self.assertTrue(not (a == b))
     254          self.assertTrue(a != b)
     255          self.assertTrue(a < b)
     256  
     257      def test_exception_message(self):
     258          class ESC[4;38;5;81mSpam:
     259              pass
     260  
     261          tests = [
     262              (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
     263              (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
     264              (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
     265              (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
     266              (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
     267              (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
     268              (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
     269              (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
     270              (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
     271              (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
     272              (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
     273              (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
     274              (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
     275          ]
     276          for i, test in enumerate(tests):
     277              with self.subTest(test=i):
     278                  with self.assertRaisesRegex(TypeError, test[1]):
     279                      test[0]()
     280  
     281  
     282  class ESC[4;38;5;81mDictTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     283  
     284      def test_dicts(self):
     285          # Verify that __eq__ and __ne__ work for dicts even if the keys and
     286          # values don't support anything other than __eq__ and __ne__ (and
     287          # __hash__).  Complex numbers are a fine example of that.
     288          import random
     289          imag1a = {}
     290          for i in range(50):
     291              imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
     292          items = list(imag1a.items())
     293          random.shuffle(items)
     294          imag1b = {}
     295          for k, v in items:
     296              imag1b[k] = v
     297          imag2 = imag1b.copy()
     298          imag2[k] = v + 1.0
     299          self.assertEqual(imag1a, imag1a)
     300          self.assertEqual(imag1a, imag1b)
     301          self.assertEqual(imag2, imag2)
     302          self.assertTrue(imag1a != imag2)
     303          for opname in ("lt", "le", "gt", "ge"):
     304              for op in opmap[opname]:
     305                  self.assertRaises(TypeError, op, imag1a, imag2)
     306  
     307  class ESC[4;38;5;81mListTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     308  
     309      def test_coverage(self):
     310          # exercise all comparisons for lists
     311          x = [42]
     312          self.assertIs(x<x, False)
     313          self.assertIs(x<=x, True)
     314          self.assertIs(x==x, True)
     315          self.assertIs(x!=x, False)
     316          self.assertIs(x>x, False)
     317          self.assertIs(x>=x, True)
     318          y = [42, 42]
     319          self.assertIs(x<y, True)
     320          self.assertIs(x<=y, True)
     321          self.assertIs(x==y, False)
     322          self.assertIs(x!=y, True)
     323          self.assertIs(x>y, False)
     324          self.assertIs(x>=y, False)
     325  
     326      def test_badentry(self):
     327          # make sure that exceptions for item comparison are properly
     328          # propagated in list comparisons
     329          class ESC[4;38;5;81mExc(ESC[4;38;5;149mException):
     330              pass
     331          class ESC[4;38;5;81mBad:
     332              def __eq__(self, other):
     333                  raise Exc
     334  
     335          x = [Bad()]
     336          y = [Bad()]
     337  
     338          for op in opmap["eq"]:
     339              self.assertRaises(Exc, op, x, y)
     340  
     341      def test_goodentry(self):
     342          # This test exercises the final call to PyObject_RichCompare()
     343          # in Objects/listobject.c::list_richcompare()
     344          class ESC[4;38;5;81mGood:
     345              def __lt__(self, other):
     346                  return True
     347  
     348          x = [Good()]
     349          y = [Good()]
     350  
     351          for op in opmap["lt"]:
     352              self.assertIs(op(x, y), True)
     353  
     354  
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()