# Tests for rich comparisons
2
import itertools
import operator
import unittest

from test import support
7
class Number:
    """Scalar wrapper whose rich comparisons delegate to the wrapped value.

    Each comparison hands ``self.x`` and *other* to the matching function
    from the ``operator`` module, so the result is whatever comparing the
    wrapped value against *other* produces.
    """

    def __init__(self, x):
        # The wrapped value every comparison is forwarded to.
        self.x = x

    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        # Python 3 never calls __cmp__; kept as a tripwire that reports any
        # accidental use as a test failure.
        raise support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return f"Number({self.x!r})"
36
class Vector:
    """Sequence wrapper whose comparisons are element-wise.

    Comparing a Vector with another Vector or an equal-length sequence
    returns a new Vector holding the per-item results rather than a bool.
    Vectors are unhashable and refuse to be truth-tested.
    """

    __hash__ = None  # Vectors cannot be hashed

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __bool__(self):
        # An element-wise comparison result has no single truth value.
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        # Never called by Python 3; a tripwire for accidental use.
        raise support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return f"Vector({self.data!r})"

    def __elementwise(self, other, compare):
        """Apply *compare* pairwise to self.data and *other*, wrapped in a Vector."""
        rhs = self.__cast(other)
        return Vector([compare(mine, theirs)
                       for mine, theirs in zip(self.data, rhs)])

    def __lt__(self, other):
        return self.__elementwise(other, operator.lt)

    def __le__(self, other):
        return self.__elementwise(other, operator.le)

    def __eq__(self, other):
        return self.__elementwise(other, operator.eq)

    def __ne__(self, other):
        return self.__elementwise(other, operator.ne)

    def __gt__(self, other):
        return self.__elementwise(other, operator.gt)

    def __ge__(self, other):
        return self.__elementwise(other, operator.ge)

    def __cast(self, other):
        """Return the raw sequence behind *other* after checking its length.

        Raises ValueError when *other* is not the same length as this Vector.
        """
        raw = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(raw):
            raise ValueError("Cannot compare vectors of different length")
        return raw
86
# Each comparison is exercised three ways: through the infix operator
# (wrapped in a lambda), through the operator-module function, and through
# its dunder alias.  Keys are the comparison names without underscores.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
95
class VectorTest(unittest.TestCase):
    """Comparisons whose results are Vectors of per-item bools, not bools."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* in opmap raises *error*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert every spelling of *opname* applied to (a, b) yields *expres*.

        The result is a Vector, so assertEqual(realres, expres) can't be
        used: comparing a Vector with a list yields another Vector, and
        asking for its truth value raises TypeError.  Items are checked one
        at a time instead.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            self.assertEqual(len(realres), len(expres))
            # Vector supports __len__/__getitem__, so zip() can iterate it.
            for got, want in zip(realres, expres):
                # Item results are bools, so identity is the exact check.
                self.assertIs(got, want)

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        short_vec = Vector(range(2))
        long_vec = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, short_vec, long_vec)

        a = list(range(5))
        b = 5 * [2]
        # Try mixed arguments, but not two plain lists: list-vs-list
        # comparison returns a single bool, not a Vector of bools.
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for left, right in args:
            with self.subTest(left=left, right=right):
                self.checkequal("lt", left, right, [True, True, False, False, False])
                self.checkequal("le", left, right, [True, True, True, False, False])
                self.checkequal("eq", left, right, [False, False, True, False, False])
                self.checkequal("ne", left, right, [True, True, False, True, True])
                self.checkequal("gt", left, right, [False, False, False, True, True])
                self.checkequal("ge", left, right, [False, False, True, True, True])

                # Every comparison yields a Vector; calling bool() on it
                # invokes __bool__, which should fail.
                for ops in opmap.values():
                    for op in ops:
                        self.assertRaises(TypeError, bool, op(left, right))
137
class NumberTest(unittest.TestCase):
    """Number objects must compare exactly like the ints they wrap."""

    def test_basic(self):
        # Check that comparisons involving Number objects
        # give the same results as comparing the
        # corresponding ints
        for a, b, typea, typeb in itertools.product(
                range(3), range(3), (int, Number), (int, Number)):
            if typea is typeb is int:
                continue  # the combination int, int is useless
            ta = typea(a)
            tb = typeb(b)
            for ops in opmap.values():
                for op in ops:
                    self.assertEqual(op(a, b), op(ta, tb))

    def checkvalue(self, opname, a, b, expres):
        """Assert *opname* on every int/Number mix of (a, b) gives *expres*.

        A result exposing an ``x`` attribute is unwrapped first; the final
        outcome must be the bool singleton *expres* itself.
        """
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres)

    def test_values(self):
        # check all operators and all comparison results
        opnames = ("lt", "le", "eq", "ne", "gt", "ge")
        cases = [
            # (a, b,  lt,    le,    eq,    ne,    gt,    ge)
            (0, 0, False, True, True, False, False, True),
            (0, 1, True, True, False, True, False, False),
            (1, 0, False, False, False, True, True, True),
        ]
        for a, b, *results in cases:
            for opname, expres in zip(opnames, results):
                self.checkvalue(opname, a, b, expres)
190
class MiscTest(unittest.TestCase):
    """Assorted rich-comparison edge cases: non-bool results, a failing
    __bool__, self-referencing containers, and TypeError messages."""

    def test_misbehavin(self):
        # Comparison methods may return non-bool values (0 here); the result
        # must be passed through, and the methods that fail the test if
        # invoked must never be reached.
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
        a = Misb()
        b = Misb()
        self.assertEqual(a < b, 0)
        self.assertEqual(a == b, 0)
        self.assertEqual(a > b, 0)

    def test_not(self):
        # Check that exceptions in __bool__ are properly
        # propagated by the not operator
        class Exc(Exception):
            pass

        class Bad:
            def __bool__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    @support.no_tracing
    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        from collections import UserList
        all_ops = (operator.eq, operator.ne, operator.lt,
                   operator.le, operator.gt, operator.ge)
        ordering_ops = all_ops[2:]  # lt, le, gt, ge

        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        # a contains b and b contains a: every comparison recurses.
        for op in all_ops:
            self.assertRaises(RecursionError, op, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        for op in ordering_ops:
            self.assertRaises(RecursionError, op, a, b)

        a.append(17)
        # Same length again, so equality has to recurse into the items.
        for op in (operator.eq, operator.ne):
            self.assertRaises(RecursionError, op, a, b)

        a.insert(0, 11)
        b.insert(0, 12)
        # The first items now differ (11 vs 12) and decide the outcome.
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)

    def test_exception_message(self):
        class Spam:
            pass

        tests = [
            (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
            (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
            (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
            (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
            (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
            (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
            (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
            (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
            (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
            (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
            (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
            (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
            (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
        ]
        for i, (func, pattern) in enumerate(tests):
            with self.subTest(test=i):
                with self.assertRaisesRegex(TypeError, pattern):
                    func()
279
280
class DictTest(unittest.TestCase):
    """Equality of dicts whose keys and values support only == and !=."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__). Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
        items = list(imag1a.items())
        random.shuffle(items)
        # The same mapping, built in a different insertion order.
        imag1b = dict(items)
        # A copy that differs from imag1a in exactly one value.
        changed_key, changed_value = items[-1]
        imag2 = imag1b.copy()
        imag2[changed_key] = changed_value + 1.0
        self.assertEqual(imag1a, imag1a)
        self.assertEqual(imag1a, imag1b)
        self.assertEqual(imag2, imag2)
        self.assertNotEqual(imag1a, imag2)
        # Dicts define no ordering, so every ordering comparison must fail.
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)
305
class ListTest(unittest.TestCase):
    """List comparisons, including delegation to item comparisons."""

    def test_coverage(self):
        # exercise all comparisons for lists
        single = [42]
        double = [42, 42]
        checks = [
            # (left, right, (lt, le, eq, ne, gt, ge))
            (single, single, (False, True, True, False, False, True)),
            # a proper prefix sorts before the longer list
            (single, double, (True, True, False, True, False, False)),
        ]
        for left, right, (lt, le, eq, ne, gt, ge) in checks:
            self.assertIs(left < right, lt)
            self.assertIs(left <= right, le)
            self.assertIs(left == right, eq)
            self.assertIs(left != right, ne)
            self.assertIs(left > right, gt)
            self.assertIs(left >= right, ge)

    def test_badentry(self):
        # An exception raised by an item's __eq__ must propagate out of
        # the list comparison instead of being swallowed.
        class Boom(Exception):
            pass

        class ExplodingEq:
            def __eq__(self, other):
                raise Boom

        left, right = [ExplodingEq()], [ExplodingEq()]
        for op in opmap["eq"]:
            with self.assertRaises(Boom):
                op(left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class AlwaysLess:
            def __lt__(self, other):
                return True

        left, right = [AlwaysLess()], [AlwaysLess()]
        for op in opmap["lt"]:
            self.assertIs(op(left, right), True)
352
353
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")