1 """
2 Test suite for socketserver.
3 """
4
5 import contextlib
6 import io
7 import os
8 import select
9 import signal
10 import socket
11 import threading
12 import unittest
13 import socketserver
14
15 import test.support
16 from test.support import reap_children, verbose
17 from test.support import os_helper
18 from test.support import socket_helper
19 from test.support import threading_helper
20
21
# Skip this whole module unless the "network" resource is enabled and
# sockets are usable on this platform.
test.support.requires("network")
test.support.requires_working_socket(module=True)


# Line echoed back by every server under test.
TEST_STR = b"hello world\n"
HOST = socket_helper.HOST

# Platform capability flags ...
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = test.support.has_fork_support

# ... and the skip decorators derived from them.
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(
    HAVE_FORKING, 'requires forking')
34
def signal_alarm(n):
    """Forward *n* to signal.alarm() where that function exists.

    On platforms without signal.alarm (e.g. Windows) this does nothing.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
39
# Remember real select() to avoid interferences with mocking
_real_select = select.select


def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for the socket to become readable,
    using the saved real select(). Raises RuntimeError on timeout.
    """
    readable = _real_select([sock], [], [], timeout)[0]
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
49
50
@test.support.requires_fork()
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a child that exits at once with status 72, and reap it on exit.

    The with-block runs while the child exists. On exit (normal or via an
    exception) the child is waited on here with an expected exit code of
    72. This checks that code run inside the block does not wait on a
    child process it did not create (Issue 1540386).

    *testcase* is accepted for API compatibility and is not used.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Exceptions from the with-block propagate unchanged; the child is
        # reaped either way.
        test.support.wait_process(pid, exitcode=72)
65
66
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort removal of AF_UNIX socket paths handed out by
        # pickaddr(); a path may never have been created or may be gone.
        for fn in self.test_files:
            with contextlib.suppress(OSError):
                os.remove(fn)
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address to bind for address family *proto*.

        AF_INET gets (HOST, 0) so the OS chooses a free port. Any other
        family gets a fresh Unix-domain path, recorded for tearDown().
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = socket_helper.create_unix_domain_name()
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a line-echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back. handle_error() closes the request and re-raises the exception
        being handled rather than only reporting it. The test is skipped if
        binding raises PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and run *testfunc* three times.

        *testfunc* is called as testfunc(address_family, server_address).
        The server is shut down, its thread joined and the server closed
        even when *testfunc* fails, so nothing outlives the test.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose:
            print("server running")
        try:
            for i in range(3):
                if verbose:
                    print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            # Stop serving on both the success and the failure path.
            if verbose:
                print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # A Unix datagram client must be bound to get a reply.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that the server does not
        # reap an unrelated child process (Issue 1540386).
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
294
295
class ErrorHandlerTest(unittest.TestCase):
    """Check which handler exceptions reach the server's handle_error().

    An ordinary exception (ValueError) raised by the request handler must
    be passed to handle_error(); an exiting exception such as SystemExit
    must not be. Handler and handle_error() both append to
    os_helper.TESTFN, which check_result() reads and tearDown() removes.
    """

    def tearDown(self):
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        """Assert the log holds the handler line, followed by the
        handle_error() line only when *handled* is true."""
        with open(os_helper.TESTFN) as log:
            logged = log.read()
        expected = 'Handler called\n' + 'Error handled\n' * handled
        self.assertEqual(logged, expected)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        self.assertRaises(SystemExit, BaseErrorTestServer, SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        with threading_helper.catch_threading_exception() as cm:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)

            self.assertIs(cm.exc_type, SystemExit)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
338
339
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request with BadHandler.

    The constructor does the whole test run: it binds to (HOST, 0),
    connects a client, handles that single request, closes the server and
    finally calls wait_done(). *exception* is the exception class that
    BadHandler raises; handle_error() records each call in
    os_helper.TESTFN.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Close the listening socket on every exit path, including a
        # failed client connection. wait_done() is reached only when the
        # request was handled without an exception escaping.
        try:
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses to block until the request is finished;
        a no-op in this base class."""
        pass
358
359
class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation to os_helper.TESTFN and then raises
    an instance of the server's ``exception`` class."""

    def handle(self):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Handler called\n')
        error_cls = self.server.exception
        raise error_cls('Test error')
365
366
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """BaseErrorTestServer combined with socketserver.ThreadingMixIn.

    wait_done() blocks until shutdown_request() has been called for the
    request, signalled through a threading.Event.
    """

    def __init__(self, *args, **kwargs):
        # The event must exist before the base constructor runs, because
        # that constructor already serves a request and calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        self.done.wait()
379
380
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        """BaseErrorTestServer combined with socketserver.ForkingMixIn;
        defined only when HAVE_FORKING is true."""
384
385
class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash the writer and both descriptors on the server so
                # the test can inspect them after handle_request().
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                payload = b'\0' * test.support.SOCK_MAX_SIZE
                srv.sent2 = self.wfile.write(payload)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        previous_handler = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, previous_handler)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            sock = socket.socket(server.address_family, socket.SOCK_STREAM,
                                 socket.IPPROTO_TCP)
            with sock, sock.makefile('rb') as reader:
                sock.connect(server.server_address)
                response1 = reader.readline()
                sock.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                keep_signalling = True
                while keep_signalling:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    keep_signalling = not interrupted.wait(timeout=float(1))
                received2 = len(reader.read())

        client_thread = threading.Thread(target=run_client)
        client_thread.start()
        server.handle_request()
        client_thread.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
461
462
class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0

            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Registered up front so the server is closed even if an assertion
        # below fails.
        self.addCleanup(server.server_close)
        # Connect and immediately close the client; the socket is closed
        # even if connect() raises.
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Registered up front so the server is closed even if the
        # assertion below fails.
        self.addCleanup(server.server_close)
        for n in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        self.assertLess(len(server._threads), 10)
511
512
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()