1 from test import support
2 import random
3 import unittest
4 from functools import cmp_to_key
5
# Verbosity flag taken from test.support; when true, the helpers and tests in
# this file print progress messages.
verbose = support.verbose
# Running count of mismatches reported by check().
nerrors = 0
8
9
def check(tag, expected, raw, compare=None):
    """Sort *raw* in place and compare the result with *expected* by identity.

    tag      -- label printed in progress and error messages.
    expected -- sequence holding the very objects *raw* should contain, in
                sorted order.
    raw      -- list to sort; it is modified in place.
    compare  -- optional old-style comparison function; when given it is
                wrapped with cmp_to_key() and used as the sort key.

    Nothing is returned and nothing is raised on a mismatch: the problem is
    printed and the module-level nerrors counter is incremented.
    """
    global nerrors

    if verbose:
        print(" checking", tag)

    snapshot = raw[:]   # the unsorted input, kept for the error report
    sort_options = {"key": cmp_to_key(compare)} if compare else {}
    raw.sort(**sort_options)

    def report(*details):
        # Shared error path: print the diagnosis and count the failure.
        global nerrors
        print("error in", tag)
        print(*details)
        print(expected)
        print(snapshot)
        print(raw)
        nerrors += 1

    if len(raw) != len(expected):
        report("length mismatch;", len(expected), len(raw))
        return

    # Identity, not equality: the sort must hand back the same objects.
    for position, (wanted, got) in enumerate(zip(expected, raw)):
        if wanted is not got:
            report("out of order at index", position, wanted, got)
            return
41
class TestBase(unittest.TestCase):
    """Stress test for list.sort() over many list sizes and element kinds."""

    def testStressfully(self):
        # check() reports a mismatch only by printing and bumping the
        # module-level nerrors counter, so remember the starting count and
        # fail the test at the end if any check added to it.
        errors_before = nerrors

        # Try a variety of sizes at and around powers of 2, and at powers of 10.
        sizes = [0]
        for power in range(1, 10):
            n = 2 ** power
            sizes.extend(range(n-1, n+2))
        sizes.extend([10, 100, 1000])

        class Complains(object):
            # Element whose comparison raises RuntimeError at random (about
            # one call in a thousand) while maybe_complain is true.
            maybe_complain = True

            def __init__(self, i):
                self.i = i

            def __lt__(self, other):
                if Complains.maybe_complain and random.random() < 0.001:
                    if verbose:
                        print(" complaining at", self, other)
                    raise RuntimeError
                return self.i < other.i

            def __repr__(self):
                return "Complains(%d)" % self.i

        class Stable(object):
            # Element ordered by key only; index records the original
            # position so that equal keys can be told apart.
            def __init__(self, key, i):
                self.key = key
                self.index = i

            def __lt__(self, other):
                return self.key < other.key

            def __repr__(self):
                return "Stable(%d, %d)" % (self.key, self.index)

        for n in sizes:
            x = list(range(n))
            if verbose:
                print("Testing size", n)

            s = x[:]
            check("identity", x, s)

            s = x[:]
            s.reverse()
            check("reversed", x, s)

            s = x[:]
            random.shuffle(s)
            check("random permutation", x, s)

            y = x[:]
            y.reverse()
            s = x[:]
            check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))

            if verbose:
                print(" Checking against an insane comparison function.")
                print(" If the implementation isn't careful, this may segfault.")
            s = x[:]
            s.sort(key=cmp_to_key(lambda a, b: int(random.random() * 3) - 1))
            # Whatever order the random comparison produced, s must still be
            # a permutation of x, so a normal sort has to restore x.
            check("an insane function left some permutation", x, s)

            if len(x) >= 2:
                def bad_key(x):
                    raise RuntimeError
                s = x[:]
                self.assertRaises(RuntimeError, s.sort, key=bad_key)

            x = [Complains(i) for i in x]
            s = x[:]
            random.shuffle(s)
            Complains.maybe_complain = True
            it_complained = False
            try:
                s.sort()
            except RuntimeError:
                it_complained = True
            if it_complained:
                # Silence the complaints so the re-sort inside check() can
                # finish and verify that s is still a permutation of x.
                Complains.maybe_complain = False
                check("exception during sort left some permutation", x, s)

            s = [Stable(random.randrange(10), i) for i in range(n)]
            augmented = [(e, e.index) for e in s]
            augmented.sort()    # forced stable because ties broken by index
            x = [e for e, i in augmented]   # a stable sort of s
            check("stability", x, s)

        self.assertEqual(nerrors, errors_before,
                         "check() reported sort errors; see the printed details")
130
131 #==============================================================================
132
class TestBugs(unittest.TestCase):
    """Regression tests: mutating a list while it is being sorted."""

    def test_bug453523(self):
        # bug 453523 -- list.sort() crasher.
        # If this fails, the most likely outcome is a core dump.
        # Mutations during a list sort should raise a ValueError.

        class Mutator:
            def __lt__(self, other):
                # Every comparison changes the list: usually shrink it,
                # otherwise (or when it is empty) grow it.
                shrink = bool(victims) and random.random() < 0.75
                if shrink:
                    victims.pop()
                else:
                    victims.append(3)
                return random.random() < 0.5

        victims = [Mutator() for _ in range(50)]
        with self.assertRaises(ValueError):
            victims.sort()

    def test_undetected_mutation(self):
        # Python 2.4a1 did not always detect mutation
        memorywaster = []
        for _ in range(20):
            def append_then_pop(x, y):
                # Leaves the list length unchanged after touching it.
                target.append(3)
                target.pop()
                return (x > y) - (x < y)

            target = [1, 2]
            key_func = cmp_to_key(append_then_pop)
            with self.assertRaises(ValueError):
                target.sort(key=key_func)

            def append_then_clear(x, y):
                # Leaves the list empty after touching it.
                target.append(3)
                del target[:]
                return (x > y) - (x < y)

            key_func = cmp_to_key(append_then_clear)
            with self.assertRaises(ValueError):
                target.sort(key=key_func)

            # Wrap the previous list in a new one on every pass.
            memorywaster = [memorywaster]
167
168 #==============================================================================
169
class TestDecorateSortUndecorate(unittest.TestCase):
    """Tests for the key= and reverse= arguments of list.sort()."""

    def test_decorated(self):
        # Sorting with key=str.lower must agree with sorting through an
        # equivalent case-insensitive comparison function.
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        copy = data[:]
        random.shuffle(data)
        data.sort(key=str.lower)
        def my_cmp(x, y):
            xlower, ylower = x.lower(), y.lower()
            return (xlower > ylower) - (xlower < ylower)
        copy.sort(key=cmp_to_key(my_cmp))
        # The only words that tie are the two identical 'The' strings, so
        # the shuffle cannot make the two results differ.
        self.assertEqual(data, copy)

    def test_baddecorator(self):
        # A key function is called with one argument; a two-argument
        # callable must therefore fail with TypeError.
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)

    def test_stability(self):
        data = [(random.randrange(100), i) for i in range(200)]
        copy = data[:]
        data.sort(key=lambda t: t[0])   # sort on the random first field
        copy.sort()                     # sort using both fields
        self.assertEqual(data, copy)    # should get the same result

    def test_key_with_exception(self):
        # Verify that the wrapper has been removed
        data = list(range(-2, 2))
        dup = data[:]
        self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
        self.assertEqual(data, dup)

    def test_key_with_mutation(self):
        # A key function that replaces the list contents must be detected.
        data = list(range(10))
        def k(x):
            del data[:]
            data[:] = range(20)
            return x
        self.assertRaises(ValueError, data.sort, key=k)

    def test_key_with_mutating_del(self):
        # Here the mutation happens when the key objects are destroyed.
        data = list(range(10))
        class SortKiller(object):
            def __init__(self, x):
                pass
            def __del__(self):
                del data[:]
                data[:] = range(20)
            def __lt__(self, other):
                return id(self) < id(other)
        self.assertRaises(ValueError, data.sort, key=SortKiller)

    def test_key_with_mutating_del_and_exception(self):
        data = list(range(10))
        class SortKiller(object):
            def __init__(self, x):
                if x > 2:
                    raise RuntimeError
            def __del__(self):
                del data[:]
                data[:] = list(range(20))
        self.assertRaises(RuntimeError, data.sort, key=SortKiller)
        # Major subtlety: we *can't* also assert here that data still has
        # its original contents, because there is a reference to a
        # SortKiller in the traceback and by the time it dies we're outside
        # the call to .sort(), so the list protection gimmicks are out of
        # date (this cost some brain cells to figure out...).

    def test_reverse(self):
        data = list(range(100))
        random.shuffle(data)
        data.sort(reverse=True)
        self.assertEqual(data, list(range(99,-1,-1)))

    def test_reverse_stability(self):
        # reverse=True must keep equal elements in their original relative
        # order, exactly like sorting with a reversed comparison.
        data = [(random.randrange(100), i) for i in range(200)]
        copy1 = data[:]
        copy2 = data[:]
        def my_cmp(x, y):
            x0, y0 = x[0], y[0]
            return (x0 > y0) - (x0 < y0)
        def my_cmp_reversed(x, y):
            x0, y0 = x[0], y[0]
            return (y0 > x0) - (y0 < x0)
        data.sort(key=cmp_to_key(my_cmp), reverse=True)
        copy1.sort(key=cmp_to_key(my_cmp_reversed))
        self.assertEqual(data, copy1)
        copy2.sort(key=lambda x: x[0], reverse=True)
        self.assertEqual(data, copy2)
261
262 #==============================================================================
def check_against_PyObject_RichCompareBool(self, L):
    # The idea here is to exploit the fact that unsafe_tuple_compare uses
    # PyObject_RichCompareBool for the second elements of tuples. So we have,
    # for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])]
    # This will work as long as __eq__ => not __lt__ for all the objects in L,
    # which holds for all the types used by the callers in this file.
    #
    # Testing this way ensures that the optimized implementation remains
    # consistent with the naive implementation, even if changes are made to
    # any of the richcompares.
    #
    # *self* is the running test case (its assertIs is used) and *L* is the
    # list to examine. L is shuffled in place after the global random
    # generator has been reseeded with 0. Three lists are then tested:
    #   1. L
    #   2. [(x,) for x in L]
    #   3. [((x,),) for x in L]

    random.seed(0)
    random.shuffle(L)
    variants = (
        L[:],
        [(item,) for item in L],
        [((item,),) for item in L],
    )
    for variant in variants:
        fast = sorted(variant)
        decorated = sorted([(0, item) for item in variant])
        for fast_item, (_, slow_item) in zip(fast, decorated):
            # note: not assertEqual! We want to ensure *identical* behavior.
            self.assertIs(fast_item, slow_item)
290
class TestOptimizedCompares(unittest.TestCase):
    """Tests aimed at the type-specialised comparison paths of list.sort()."""

    def test_safe_object_compare(self):
        # Mixed-type lists must fail with TypeError, whether bare or wrapped
        # in one or two levels of tuples.
        heterogeneous_lists = [[0, 'foo'],
                               [0.0, 'foo'],
                               [('foo',), 'foo']]
        for mixed in heterogeneous_lists:
            with self.assertRaises(TypeError):
                mixed.sort()
            singles = [(item,) for item in mixed]
            with self.assertRaises(TypeError):
                singles.sort()
            doubles = [((item,),) for item in mixed]
            with self.assertRaises(TypeError):
                doubles.sort()

        # ints and floats are mutually comparable, so these must sort.
        float_int_lists = [[1,1.1],
                           [1<<70,1.1],
                           [1.1,1],
                           [1.1,1<<70]]
        for numbers in float_int_lists:
            check_against_PyObject_RichCompareBool(self, numbers)

    def test_unsafe_object_compare(self):

        # This test is by ppperry. It ensures that unsafe_object_compare is
        # verifying ms->key_richcompare == tp->richcompare before comparing.

        class WackyComparator(int):
            def __lt__(self, other):
                # Swap the class of the chosen list in the middle of a sort.
                victim.__class__ = WackyList2
                return int.__lt__(self, other)

        class WackyList1(list):
            pass

        class WackyList2(list):
            def __lt__(self, other):
                raise ValueError

        rows = [WackyList1([WackyComparator(i), i]) for i in range(10)]
        victim = rows[-1]
        with self.assertRaises(ValueError):
            rows.sort()

        rows = [WackyList1([WackyComparator(i), i]) for i in range(10)]
        victim = rows[-1]
        with self.assertRaises(ValueError):
            [(row,) for row in rows].sort()

        # The following test is also by ppperry. It ensures that
        # unsafe_object_compare handles Py_NotImplemented appropriately.
        class PointlessComparator:
            def __lt__(self, other):
                return NotImplemented

        pair = [PointlessComparator(), PointlessComparator()]
        with self.assertRaises(TypeError):
            pair.sort()
        wrapped_pair = [(item,) for item in pair]
        with self.assertRaises(TypeError):
            wrapped_pair.sort()

        # The following tests go through various types that would trigger
        # ms->key_compare = unsafe_object_compare
        samples = [
            list(range(100)) + [(1<<70)],
            [str(x) for x in range(100)] + ['\uffff'],
            [bytes(x) for x in range(100)],
            # cmp_to_key() is deliberately called once per element.
            [cmp_to_key(lambda x,y: x<y)(x) for x in range(100)],
        ]
        for sample in samples:
            check_against_PyObject_RichCompareBool(self, sample)

    def test_unsafe_latin_compare(self):
        digits = [str(x) for x in range(100)]
        check_against_PyObject_RichCompareBool(self, digits)

    def test_unsafe_long_compare(self):
        integers = list(range(100))
        check_against_PyObject_RichCompareBool(self, integers)

    def test_unsafe_float_compare(self):
        floats = [float(x) for x in range(100)]
        check_against_PyObject_RichCompareBool(self, floats)

    def test_unsafe_tuple_compare(self):
        # This test was suggested by Tim Peters. It verifies that the tuple
        # comparison respects the current tuple compare semantics, which do not
        # guarantee that x < x <=> (x,) < (x,)
        #
        # Note that we don't have to put anything in tuples here, because
        # the check function does a tuple test automatically.

        # One NaN object repeated, then one hundred distinct NaN objects.
        shared_nan = [float('nan')]*100
        check_against_PyObject_RichCompareBool(self, shared_nan)
        distinct_nans = [float('nan') for _ in range(100)]
        check_against_PyObject_RichCompareBool(self, distinct_nans)

    def test_not_all_tuples(self):
        unsortable = (
            [(1.0, 1.0), (False, "A"), 6],
            [('a', 1), (1, 'a')],
            [(1, 'a'), ('a', 1)],
        )
        for candidate in unsortable:
            with self.assertRaises(TypeError):
                candidate.sort()

    def test_none_in_tuples(self):
        # Equal first items (None) are skipped; ordering falls to the
        # second items.
        actual = sorted([(None, 2), (None, 1)])
        expected = [(None, 1), (None, 2)]
        self.assertEqual(actual, expected)
386
387 #==============================================================================
388
# Run every test case in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()