1 """
2 Test suite for socketserver.
3 """
4
5 import contextlib
6 import io
7 import os
8 import select
9 import signal
10 import socket
11 import tempfile
12 import threading
13 import unittest
14 import socketserver
15
16 import test.support
17 from test.support import reap_children, verbose
18 from test.support import os_helper
19 from test.support import socket_helper
20 from test.support import threading_helper
21
22
# The whole module needs network resources and working sockets; these calls
# run at import time and gate every test below.
test.support.requires("network")
test.support.requires_working_socket(module=True)


TEST_STR = b"hello world\n"
HOST = socket_helper.HOST

# Feature probes and matching skip decorators.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, reason='requires Unix sockets')
HAVE_FORKING = test.support.has_fork_support
requires_forking = unittest.skipUnless(HAVE_FORKING, reason='requires forking')

# Remember real select() to avoid interferences with mocking
_real_select = select.select
38
def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable, using the
    saved, unmocked select(). Raises RuntimeError if it does not.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
45
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # Unix-domain servers combined with ForkingMixIn, for the forking
    # Unix stream/datagram tests in SocketServerTest.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        ...

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        ...
54
55
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that exits immediately with status 72, runs the body of
    the ``with`` statement, then (whether or not the body raised) waits for
    that child via test.support.wait_process(pid, exitcode=72).

    *testcase* is accepted for call-site compatibility but is not used.
    """
    pid = os.fork()
    if pid == 0:
        # Child: don't raise an exception; it would be caught by the test
        # harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Exceptions from the body propagate unchanged after the child is
        # waited for; no bare ``except: raise`` is needed.
        test.support.wait_process(pid, exitcode=72)
69
70
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        reap_children()

        # Best-effort removal of the socket paths handed out by pickaddr();
        # a path may never have been created, so errors are ignored.
        for fn in self.test_files:
            with contextlib.suppress(OSError):
                os.remove(fn)
        self.test_files.clear()

    def pickaddr(self, proto):
        """Return a bind address for a server of address family *proto*.

        AF_INET gets ``(HOST, 0)`` so the OS picks a free port. Any other
        family gets a fresh temporary path, recorded in ``self.test_files``
        so that tearDown() removes it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a one-line echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back. The server's handle_error() closes the request and re-raises
        the exception being handled, so handler failures are not swallowed.
        Skips the test if binding raises PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and run *testfunc* three times.

        *testfunc* is called as ``testfunc(address_family, addr)``. The server
        is always shut down, its thread joined and its socket closed, even if
        *testfunc* raises.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            # Stop the serving thread and release the listening socket even
            # when a client check failed, so later tests don't inherit them.
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def _receive_line(self, sock):
        """Read from *sock* until a newline is seen or an empty read occurs."""
        buf = data = receive(sock, 100)
        while data and b'\n' not in buf:
            data = receive(sock, 100)
            buf += data
        return buf

    def stream_examine(self, proto, addr):
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            self.assertEqual(self._receive_line(s), TEST_STR)

    def dgram_examine(self, proto, addr):
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Give the Unix datagram client its own bound path so the
                # server's reply has an address to go to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            self.assertEqual(self._receive_line(s), TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        # Like the other forking tests, also check that an unrelated child
        # process is not waited on by the server (Issue 1540386).
        with simple_subprocess(self):
            self.run_server(ForkingUnixDatagramServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for _ in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for _ in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
297
298
class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        # The log always begins with the handler's line; handle_error()
        # contributes a second line only when it was invoked.
        expected_lines = ['Handler called\n']
        if handled:
            expected_lines.append('Error handled\n')
        with open(os_helper.TESTFN) as log:
            self.assertEqual(log.read(), ''.join(expected_lines))

    # --- synchronous server -------------------------------------------

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # --- threading server ---------------------------------------------

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        with threading_helper.catch_threading_exception() as caught:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)

            self.assertIs(caught.exc_type, SystemExit)

    # --- forking server -----------------------------------------------

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
341
342
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request from its constructor.

    The request is handled by BadHandler, which raises ``self.exception``.
    handle_error() appends 'Error handled' to os_helper.TESTFN, so tests can
    tell whether the exception reached it.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Open and immediately close a client connection so that
        # handle_request() below has a pending connection to accept.
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
            self.wait_done()

    def handle_error(self, request, client_address):
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Error handled\n')

    def wait_done(self):
        """Hook run after server_close(); subclasses may block here."""
361
362
class BadHandler(socketserver.BaseRequestHandler):
    def handle(self):
        # Record that the handler ran, then fail with whatever exception
        # class the owning server was configured with.
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Handler called\n')
        exc_type = self.server.exception
        raise exc_type('Test error')
368
369
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    def __init__(self, *args, **kwargs):
        # Must exist before BaseErrorTestServer.__init__ runs, because that
        # constructor handles a request and then calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        # Only signal completion once the request has been shut down.
        self.done.set()

    def wait_done(self):
        self.done.wait()
382
383
if HAVE_FORKING:
    # Error-test server whose request is handled via ForkingMixIn.
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        ...
387
388
class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Publish the handler's writer and descriptors for the test.
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                srv.sent2 = self.wfile.write(
                    b'\0' * test.support.SOCK_MAX_SIZE)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def on_sigusr1(signum, frame):
            interrupted.set()

        previous_handler = signal.signal(signal.SIGUSR1, on_sigusr1)
        self.addCleanup(signal.signal, signal.SIGUSR1, previous_handler)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            client = socket.socket(server.address_family, socket.SOCK_STREAM,
                                   socket.IPPROTO_TCP)
            with client, client.makefile('rb') as reader:
                client.connect(server.server_address)
                response1 = reader.readline()
                client.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                pthread_kill(main_thread, signal.SIGUSR1)
                while not interrupted.wait(timeout=1.0):
                    pthread_kill(main_thread, signal.SIGUSR1)
                received2 = len(reader.read())

        client_thread = threading.Thread(target=run_client)
        client_thread.start()
        server.handle_request()
        client_thread.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
464
465
class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        def defined_in_socketserver(name):
            obj = getattr(socketserver, name)
            return getattr(obj, '__module__', None) == 'socketserver'

        expected = [name for name in dir(socketserver)
                    if not name.startswith('_')
                    and defined_in_socketserver(name)]
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Release the listening socket even if the assertion below fails.
        self.addCleanup(server.server_close)
        # Connect and disconnect so handle_request() has a connection to
        # accept (and then reject via verify_request).
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the server (joining its threads) even if the check fails.
        self.addCleanup(server.server_close)
        for _ in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        self.assertLess(len(server._threads), 10)
514
515
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()