(root)/
Python-3.11.7/
Lib/
test/
test_socketserver.py
       1  """
       2  Test suite for socketserver.
       3  """
       4  
       5  import contextlib
       6  import io
       7  import os
       8  import select
       9  import signal
      10  import socket
      11  import tempfile
      12  import threading
      13  import unittest
      14  import socketserver
      15  
      16  import test.support
      17  from test.support import reap_children, verbose
      18  from test.support import os_helper
      19  from test.support import socket_helper
      20  from test.support import threading_helper
      21  
      22  
      23  test.support.requires("network")
      24  test.support.requires_working_socket(module=True)
      25  
      26  
      27  TEST_STR = b"hello world\n"
      28  HOST = socket_helper.HOST
      29  
      30  HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
      31  requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
      32                                              'requires Unix sockets')
      33  HAVE_FORKING = test.support.has_fork_support
      34  requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
      35  
      36  # Remember real select() to avoid interferences with mocking
      37  _real_select = select.select
      38  
      39  def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
      40      r, w, x = _real_select([sock], [], [], timeout)
      41      if sock in r:
      42          return sock.recv(n)
      43      else:
      44          raise RuntimeError("timed out on %r" % (sock,))
      45  
      46  if HAVE_UNIX_SOCKETS and HAVE_FORKING:
      47      class ESC[4;38;5;81mForkingUnixStreamServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mForkingMixIn,
      48                                    ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mUnixStreamServer):
      49          pass
      50  
      51      class ESC[4;38;5;81mForkingUnixDatagramServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mForkingMixIn,
      52                                      ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mUnixDatagramServer):
      53          pass
      54  
      55  
      56  @contextlib.contextmanager
      57  def simple_subprocess(testcase):
      58      """Tests that a custom child process is not waited on (Issue 1540386)"""
      59      pid = os.fork()
      60      if pid == 0:
      61          # Don't raise an exception; it would be caught by the test harness.
      62          os._exit(72)
      63      try:
      64          yield None
      65      except:
      66          raise
      67      finally:
      68          test.support.wait_process(pid, exitcode=72)
      69  
      70  
      71  class ESC[4;38;5;81mSocketServerTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      72      """Test all socket servers."""
      73  
      74      def setUp(self):
      75          self.port_seed = 0
      76          self.test_files = []
      77  
      78      def tearDown(self):
      79          reap_children()
      80  
      81          for fn in self.test_files:
      82              try:
      83                  os.remove(fn)
      84              except OSError:
      85                  pass
      86          self.test_files[:] = []
      87  
      88      def pickaddr(self, proto):
      89          if proto == socket.AF_INET:
      90              return (HOST, 0)
      91          else:
      92              # XXX: We need a way to tell AF_UNIX to pick its own name
      93              # like AF_INET provides port==0.
      94              dir = None
      95              fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
      96              self.test_files.append(fn)
      97              return fn
      98  
      99      def make_server(self, addr, svrcls, hdlrbase):
     100          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msvrcls):
     101              def handle_error(self, request, client_address):
     102                  self.close_request(request)
     103                  raise
     104  
     105          class ESC[4;38;5;81mMyHandler(ESC[4;38;5;149mhdlrbase):
     106              def handle(self):
     107                  line = self.rfile.readline()
     108                  self.wfile.write(line)
     109  
     110          if verbose: print("creating server")
     111          try:
     112              server = MyServer(addr, MyHandler)
     113          except PermissionError as e:
     114              # Issue 29184: cannot bind() a Unix socket on Android.
     115              self.skipTest('Cannot create server (%s, %s): %s' %
     116                            (svrcls, addr, e))
     117          self.assertEqual(server.server_address, server.socket.getsockname())
     118          return server
     119  
     120      @threading_helper.reap_threads
     121      def run_server(self, svrcls, hdlrbase, testfunc):
     122          server = self.make_server(self.pickaddr(svrcls.address_family),
     123                                    svrcls, hdlrbase)
     124          # We had the OS pick a port, so pull the real address out of
     125          # the server.
     126          addr = server.server_address
     127          if verbose:
     128              print("ADDR =", addr)
     129              print("CLASS =", svrcls)
     130  
     131          t = threading.Thread(
     132              name='%s serving' % svrcls,
     133              target=server.serve_forever,
     134              # Short poll interval to make the test finish quickly.
     135              # Time between requests is short enough that we won't wake
     136              # up spuriously too many times.
     137              kwargs={'poll_interval':0.01})
     138          t.daemon = True  # In case this function raises.
     139          t.start()
     140          if verbose: print("server running")
     141          for i in range(3):
     142              if verbose: print("test client", i)
     143              testfunc(svrcls.address_family, addr)
     144          if verbose: print("waiting for server")
     145          server.shutdown()
     146          t.join()
     147          server.server_close()
     148          self.assertEqual(-1, server.socket.fileno())
     149          if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
     150              # bpo-31151: Check that ForkingMixIn.server_close() waits until
     151              # all children completed
     152              self.assertFalse(server.active_children)
     153          if verbose: print("done")
     154  
     155      def stream_examine(self, proto, addr):
     156          with socket.socket(proto, socket.SOCK_STREAM) as s:
     157              s.connect(addr)
     158              s.sendall(TEST_STR)
     159              buf = data = receive(s, 100)
     160              while data and b'\n' not in buf:
     161                  data = receive(s, 100)
     162                  buf += data
     163              self.assertEqual(buf, TEST_STR)
     164  
     165      def dgram_examine(self, proto, addr):
     166          with socket.socket(proto, socket.SOCK_DGRAM) as s:
     167              if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
     168                  s.bind(self.pickaddr(proto))
     169              s.sendto(TEST_STR, addr)
     170              buf = data = receive(s, 100)
     171              while data and b'\n' not in buf:
     172                  data = receive(s, 100)
     173                  buf += data
     174              self.assertEqual(buf, TEST_STR)
     175  
     176      def test_TCPServer(self):
     177          self.run_server(socketserver.TCPServer,
     178                          socketserver.StreamRequestHandler,
     179                          self.stream_examine)
     180  
     181      def test_ThreadingTCPServer(self):
     182          self.run_server(socketserver.ThreadingTCPServer,
     183                          socketserver.StreamRequestHandler,
     184                          self.stream_examine)
     185  
     186      @requires_forking
     187      def test_ForkingTCPServer(self):
     188          with simple_subprocess(self):
     189              self.run_server(socketserver.ForkingTCPServer,
     190                              socketserver.StreamRequestHandler,
     191                              self.stream_examine)
     192  
     193      @requires_unix_sockets
     194      def test_UnixStreamServer(self):
     195          self.run_server(socketserver.UnixStreamServer,
     196                          socketserver.StreamRequestHandler,
     197                          self.stream_examine)
     198  
     199      @requires_unix_sockets
     200      def test_ThreadingUnixStreamServer(self):
     201          self.run_server(socketserver.ThreadingUnixStreamServer,
     202                          socketserver.StreamRequestHandler,
     203                          self.stream_examine)
     204  
     205      @requires_unix_sockets
     206      @requires_forking
     207      def test_ForkingUnixStreamServer(self):
     208          with simple_subprocess(self):
     209              self.run_server(ForkingUnixStreamServer,
     210                              socketserver.StreamRequestHandler,
     211                              self.stream_examine)
     212  
     213      def test_UDPServer(self):
     214          self.run_server(socketserver.UDPServer,
     215                          socketserver.DatagramRequestHandler,
     216                          self.dgram_examine)
     217  
     218      def test_ThreadingUDPServer(self):
     219          self.run_server(socketserver.ThreadingUDPServer,
     220                          socketserver.DatagramRequestHandler,
     221                          self.dgram_examine)
     222  
     223      @requires_forking
     224      def test_ForkingUDPServer(self):
     225          with simple_subprocess(self):
     226              self.run_server(socketserver.ForkingUDPServer,
     227                              socketserver.DatagramRequestHandler,
     228                              self.dgram_examine)
     229  
     230      @requires_unix_sockets
     231      def test_UnixDatagramServer(self):
     232          self.run_server(socketserver.UnixDatagramServer,
     233                          socketserver.DatagramRequestHandler,
     234                          self.dgram_examine)
     235  
     236      @requires_unix_sockets
     237      def test_ThreadingUnixDatagramServer(self):
     238          self.run_server(socketserver.ThreadingUnixDatagramServer,
     239                          socketserver.DatagramRequestHandler,
     240                          self.dgram_examine)
     241  
     242      @requires_unix_sockets
     243      @requires_forking
     244      def test_ForkingUnixDatagramServer(self):
     245          self.run_server(ForkingUnixDatagramServer,
     246                          socketserver.DatagramRequestHandler,
     247                          self.dgram_examine)
     248  
     249      @threading_helper.reap_threads
     250      def test_shutdown(self):
     251          # Issue #2302: shutdown() should always succeed in making an
     252          # other thread leave serve_forever().
     253          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     254              pass
     255  
     256          class ESC[4;38;5;81mMyHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     257              pass
     258  
     259          threads = []
     260          for i in range(20):
     261              s = MyServer((HOST, 0), MyHandler)
     262              t = threading.Thread(
     263                  name='MyServer serving',
     264                  target=s.serve_forever,
     265                  kwargs={'poll_interval':0.01})
     266              t.daemon = True  # In case this function raises.
     267              threads.append((t, s))
     268          for t, s in threads:
     269              t.start()
     270              s.shutdown()
     271          for t, s in threads:
     272              t.join()
     273              s.server_close()
     274  
     275      def test_close_immediately(self):
     276          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn, ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     277              pass
     278  
     279          server = MyServer((HOST, 0), lambda: None)
     280          server.server_close()
     281  
     282      def test_tcpserver_bind_leak(self):
     283          # Issue #22435: the server socket wouldn't be closed if bind()/listen()
     284          # failed.
     285          # Create many servers for which bind() will fail, to see if this result
     286          # in FD exhaustion.
     287          for i in range(1024):
     288              with self.assertRaises(OverflowError):
     289                  socketserver.TCPServer((HOST, -1),
     290                                         socketserver.StreamRequestHandler)
     291  
     292      def test_context_manager(self):
     293          with socketserver.TCPServer((HOST, 0),
     294                                      socketserver.StreamRequestHandler) as server:
     295              pass
     296          self.assertEqual(-1, server.socket.fileno())
     297  
     298  
     299  class ESC[4;38;5;81mErrorHandlerTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     300      """Test that the servers pass normal exceptions from the handler to
     301      handle_error(), and that exiting exceptions like SystemExit and
     302      KeyboardInterrupt are not passed."""
     303  
     304      def tearDown(self):
     305          os_helper.unlink(os_helper.TESTFN)
     306  
     307      def test_sync_handled(self):
     308          BaseErrorTestServer(ValueError)
     309          self.check_result(handled=True)
     310  
     311      def test_sync_not_handled(self):
     312          with self.assertRaises(SystemExit):
     313              BaseErrorTestServer(SystemExit)
     314          self.check_result(handled=False)
     315  
     316      def test_threading_handled(self):
     317          ThreadingErrorTestServer(ValueError)
     318          self.check_result(handled=True)
     319  
     320      def test_threading_not_handled(self):
     321          with threading_helper.catch_threading_exception() as cm:
     322              ThreadingErrorTestServer(SystemExit)
     323              self.check_result(handled=False)
     324  
     325              self.assertIs(cm.exc_type, SystemExit)
     326  
     327      @requires_forking
     328      def test_forking_handled(self):
     329          ForkingErrorTestServer(ValueError)
     330          self.check_result(handled=True)
     331  
     332      @requires_forking
     333      def test_forking_not_handled(self):
     334          ForkingErrorTestServer(SystemExit)
     335          self.check_result(handled=False)
     336  
     337      def check_result(self, handled):
     338          with open(os_helper.TESTFN) as log:
     339              expected = 'Handler called\n' + 'Error handled\n' * handled
     340              self.assertEqual(log.read(), expected)
     341  
     342  
     343  class ESC[4;38;5;81mBaseErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     344      def __init__(self, exception):
     345          self.exception = exception
     346          super().__init__((HOST, 0), BadHandler)
     347          with socket.create_connection(self.server_address):
     348              pass
     349          try:
     350              self.handle_request()
     351          finally:
     352              self.server_close()
     353          self.wait_done()
     354  
     355      def handle_error(self, request, client_address):
     356          with open(os_helper.TESTFN, 'a') as log:
     357              log.write('Error handled\n')
     358  
     359      def wait_done(self):
     360          pass
     361  
     362  
     363  class ESC[4;38;5;81mBadHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mBaseRequestHandler):
     364      def handle(self):
     365          with open(os_helper.TESTFN, 'a') as log:
     366              log.write('Handler called\n')
     367          raise self.server.exception('Test error')
     368  
     369  
     370  class ESC[4;38;5;81mThreadingErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn,
     371          ESC[4;38;5;149mBaseErrorTestServer):
     372      def __init__(self, *pos, **kw):
     373          self.done = threading.Event()
     374          super().__init__(*pos, **kw)
     375  
     376      def shutdown_request(self, *pos, **kw):
     377          super().shutdown_request(*pos, **kw)
     378          self.done.set()
     379  
     380      def wait_done(self):
     381          self.done.wait()
     382  
     383  
     384  if HAVE_FORKING:
     385      class ESC[4;38;5;81mForkingErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mForkingMixIn, ESC[4;38;5;149mBaseErrorTestServer):
     386          pass
     387  
     388  
     389  class ESC[4;38;5;81mSocketWriterTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     390      def test_basics(self):
     391          class ESC[4;38;5;81mHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     392              def handle(self):
     393                  self.server.wfile = self.wfile
     394                  self.server.wfile_fileno = self.wfile.fileno()
     395                  self.server.request_fileno = self.request.fileno()
     396  
     397          server = socketserver.TCPServer((HOST, 0), Handler)
     398          self.addCleanup(server.server_close)
     399          s = socket.socket(
     400              server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
     401          with s:
     402              s.connect(server.server_address)
     403          server.handle_request()
     404          self.assertIsInstance(server.wfile, io.BufferedIOBase)
     405          self.assertEqual(server.wfile_fileno, server.request_fileno)
     406  
     407      def test_write(self):
     408          # Test that wfile.write() sends data immediately, and that it does
     409          # not truncate sends when interrupted by a Unix signal
     410          pthread_kill = test.support.get_attribute(signal, 'pthread_kill')
     411  
     412          class ESC[4;38;5;81mHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     413              def handle(self):
     414                  self.server.sent1 = self.wfile.write(b'write data\n')
     415                  # Should be sent immediately, without requiring flush()
     416                  self.server.received = self.rfile.readline()
     417                  big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
     418                  self.server.sent2 = self.wfile.write(big_chunk)
     419  
     420          server = socketserver.TCPServer((HOST, 0), Handler)
     421          self.addCleanup(server.server_close)
     422          interrupted = threading.Event()
     423  
     424          def signal_handler(signum, frame):
     425              interrupted.set()
     426  
     427          original = signal.signal(signal.SIGUSR1, signal_handler)
     428          self.addCleanup(signal.signal, signal.SIGUSR1, original)
     429          response1 = None
     430          received2 = None
     431          main_thread = threading.get_ident()
     432  
     433          def run_client():
     434              s = socket.socket(server.address_family, socket.SOCK_STREAM,
     435                  socket.IPPROTO_TCP)
     436              with s, s.makefile('rb') as reader:
     437                  s.connect(server.server_address)
     438                  nonlocal response1
     439                  response1 = reader.readline()
     440                  s.sendall(b'client response\n')
     441  
     442                  reader.read(100)
     443                  # The main thread should now be blocking in a send() syscall.
     444                  # But in theory, it could get interrupted by other signals,
     445                  # and then retried. So keep sending the signal in a loop, in
     446                  # case an earlier signal happens to be delivered at an
     447                  # inconvenient moment.
     448                  while True:
     449                      pthread_kill(main_thread, signal.SIGUSR1)
     450                      if interrupted.wait(timeout=float(1)):
     451                          break
     452                  nonlocal received2
     453                  received2 = len(reader.read())
     454  
     455          background = threading.Thread(target=run_client)
     456          background.start()
     457          server.handle_request()
     458          background.join()
     459          self.assertEqual(server.sent1, len(response1))
     460          self.assertEqual(response1, b'write data\n')
     461          self.assertEqual(server.received, b'client response\n')
     462          self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
     463          self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
     464  
     465  
     466  class ESC[4;38;5;81mMiscTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     467  
     468      def test_all(self):
     469          # objects defined in the module should be in __all__
     470          expected = []
     471          for name in dir(socketserver):
     472              if not name.startswith('_'):
     473                  mod_object = getattr(socketserver, name)
     474                  if getattr(mod_object, '__module__', None) == 'socketserver':
     475                      expected.append(name)
     476          self.assertCountEqual(socketserver.__all__, expected)
     477  
     478      def test_shutdown_request_called_if_verify_request_false(self):
     479          # Issue #26309: BaseServer should call shutdown_request even if
     480          # verify_request is False
     481  
     482          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     483              def verify_request(self, request, client_address):
     484                  return False
     485  
     486              shutdown_called = 0
     487              def shutdown_request(self, request):
     488                  self.shutdown_called += 1
     489                  socketserver.TCPServer.shutdown_request(self, request)
     490  
     491          server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
     492          s = socket.socket(server.address_family, socket.SOCK_STREAM)
     493          s.connect(server.server_address)
     494          s.close()
     495          server.handle_request()
     496          self.assertEqual(server.shutdown_called, 1)
     497          server.server_close()
     498  
     499      def test_threads_reaped(self):
     500          """
     501          In #37193, users reported a memory leak
     502          due to the saving of every request thread. Ensure that
     503          not all threads are kept forever.
     504          """
     505          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn, ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     506              pass
     507  
     508          server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
     509          for n in range(10):
     510              with socket.create_connection(server.server_address):
     511                  server.handle_request()
     512          self.assertLess(len(server._threads), 10)
     513          server.server_close()
     514  
     515  
     516  if __name__ == "__main__":
     517      unittest.main()