(root)/
Python-3.12.0/
Lib/
test/
test_socketserver.py
       1  """
       2  Test suite for socketserver.
       3  """
       4  
       5  import contextlib
       6  import io
       7  import os
       8  import select
       9  import signal
      10  import socket
      11  import threading
      12  import unittest
      13  import socketserver
      14  
      15  import test.support
      16  from test.support import reap_children, verbose
      17  from test.support import os_helper
      18  from test.support import socket_helper
      19  from test.support import threading_helper
      20  
      21  
      22  test.support.requires("network")
      23  test.support.requires_working_socket(module=True)
      24  
      25  
      26  TEST_STR = b"hello world\n"
      27  HOST = socket_helper.HOST
      28  
      29  HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
      30  requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
      31                                              'requires Unix sockets')
      32  HAVE_FORKING = test.support.has_fork_support
      33  requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
      34  
      35  def signal_alarm(n):
      36      """Call signal.alarm when it exists (i.e. not on Windows)."""
      37      if hasattr(signal, 'alarm'):
      38          signal.alarm(n)
      39  
      40  # Remember real select() to avoid interferences with mocking
      41  _real_select = select.select
      42  
      43  def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
      44      r, w, x = _real_select([sock], [], [], timeout)
      45      if sock in r:
      46          return sock.recv(n)
      47      else:
      48          raise RuntimeError("timed out on %r" % (sock,))
      49  
      50  
      51  @test.support.requires_fork()
      52  @contextlib.contextmanager
      53  def simple_subprocess(testcase):
      54      """Tests that a custom child process is not waited on (Issue 1540386)"""
      55      pid = os.fork()
      56      if pid == 0:
      57          # Don't raise an exception; it would be caught by the test harness.
      58          os._exit(72)
      59      try:
      60          yield None
      61      except:
      62          raise
      63      finally:
      64          test.support.wait_process(pid, exitcode=72)
      65  
      66  
      67  class ESC[4;38;5;81mSocketServerTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
      68      """Test all socket servers."""
      69  
      70      def setUp(self):
      71          signal_alarm(60)  # Kill deadlocks after 60 seconds.
      72          self.port_seed = 0
      73          self.test_files = []
      74  
      75      def tearDown(self):
      76          signal_alarm(0)  # Didn't deadlock.
      77          reap_children()
      78  
      79          for fn in self.test_files:
      80              try:
      81                  os.remove(fn)
      82              except OSError:
      83                  pass
      84          self.test_files[:] = []
      85  
      86      def pickaddr(self, proto):
      87          if proto == socket.AF_INET:
      88              return (HOST, 0)
      89          else:
      90              # XXX: We need a way to tell AF_UNIX to pick its own name
      91              # like AF_INET provides port==0.
      92              fn = socket_helper.create_unix_domain_name()
      93              self.test_files.append(fn)
      94              return fn
      95  
      96      def make_server(self, addr, svrcls, hdlrbase):
      97          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msvrcls):
      98              def handle_error(self, request, client_address):
      99                  self.close_request(request)
     100                  raise
     101  
     102          class ESC[4;38;5;81mMyHandler(ESC[4;38;5;149mhdlrbase):
     103              def handle(self):
     104                  line = self.rfile.readline()
     105                  self.wfile.write(line)
     106  
     107          if verbose: print("creating server")
     108          try:
     109              server = MyServer(addr, MyHandler)
     110          except PermissionError as e:
     111              # Issue 29184: cannot bind() a Unix socket on Android.
     112              self.skipTest('Cannot create server (%s, %s): %s' %
     113                            (svrcls, addr, e))
     114          self.assertEqual(server.server_address, server.socket.getsockname())
     115          return server
     116  
     117      @threading_helper.reap_threads
     118      def run_server(self, svrcls, hdlrbase, testfunc):
     119          server = self.make_server(self.pickaddr(svrcls.address_family),
     120                                    svrcls, hdlrbase)
     121          # We had the OS pick a port, so pull the real address out of
     122          # the server.
     123          addr = server.server_address
     124          if verbose:
     125              print("ADDR =", addr)
     126              print("CLASS =", svrcls)
     127  
     128          t = threading.Thread(
     129              name='%s serving' % svrcls,
     130              target=server.serve_forever,
     131              # Short poll interval to make the test finish quickly.
     132              # Time between requests is short enough that we won't wake
     133              # up spuriously too many times.
     134              kwargs={'poll_interval':0.01})
     135          t.daemon = True  # In case this function raises.
     136          t.start()
     137          if verbose: print("server running")
     138          for i in range(3):
     139              if verbose: print("test client", i)
     140              testfunc(svrcls.address_family, addr)
     141          if verbose: print("waiting for server")
     142          server.shutdown()
     143          t.join()
     144          server.server_close()
     145          self.assertEqual(-1, server.socket.fileno())
     146          if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
     147              # bpo-31151: Check that ForkingMixIn.server_close() waits until
     148              # all children completed
     149              self.assertFalse(server.active_children)
     150          if verbose: print("done")
     151  
     152      def stream_examine(self, proto, addr):
     153          with socket.socket(proto, socket.SOCK_STREAM) as s:
     154              s.connect(addr)
     155              s.sendall(TEST_STR)
     156              buf = data = receive(s, 100)
     157              while data and b'\n' not in buf:
     158                  data = receive(s, 100)
     159                  buf += data
     160              self.assertEqual(buf, TEST_STR)
     161  
     162      def dgram_examine(self, proto, addr):
     163          with socket.socket(proto, socket.SOCK_DGRAM) as s:
     164              if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
     165                  s.bind(self.pickaddr(proto))
     166              s.sendto(TEST_STR, addr)
     167              buf = data = receive(s, 100)
     168              while data and b'\n' not in buf:
     169                  data = receive(s, 100)
     170                  buf += data
     171              self.assertEqual(buf, TEST_STR)
     172  
     173      def test_TCPServer(self):
     174          self.run_server(socketserver.TCPServer,
     175                          socketserver.StreamRequestHandler,
     176                          self.stream_examine)
     177  
     178      def test_ThreadingTCPServer(self):
     179          self.run_server(socketserver.ThreadingTCPServer,
     180                          socketserver.StreamRequestHandler,
     181                          self.stream_examine)
     182  
     183      @requires_forking
     184      def test_ForkingTCPServer(self):
     185          with simple_subprocess(self):
     186              self.run_server(socketserver.ForkingTCPServer,
     187                              socketserver.StreamRequestHandler,
     188                              self.stream_examine)
     189  
     190      @requires_unix_sockets
     191      def test_UnixStreamServer(self):
     192          self.run_server(socketserver.UnixStreamServer,
     193                          socketserver.StreamRequestHandler,
     194                          self.stream_examine)
     195  
     196      @requires_unix_sockets
     197      def test_ThreadingUnixStreamServer(self):
     198          self.run_server(socketserver.ThreadingUnixStreamServer,
     199                          socketserver.StreamRequestHandler,
     200                          self.stream_examine)
     201  
     202      @requires_unix_sockets
     203      @requires_forking
     204      def test_ForkingUnixStreamServer(self):
     205          with simple_subprocess(self):
     206              self.run_server(socketserver.ForkingUnixStreamServer,
     207                              socketserver.StreamRequestHandler,
     208                              self.stream_examine)
     209  
     210      def test_UDPServer(self):
     211          self.run_server(socketserver.UDPServer,
     212                          socketserver.DatagramRequestHandler,
     213                          self.dgram_examine)
     214  
     215      def test_ThreadingUDPServer(self):
     216          self.run_server(socketserver.ThreadingUDPServer,
     217                          socketserver.DatagramRequestHandler,
     218                          self.dgram_examine)
     219  
     220      @requires_forking
     221      def test_ForkingUDPServer(self):
     222          with simple_subprocess(self):
     223              self.run_server(socketserver.ForkingUDPServer,
     224                              socketserver.DatagramRequestHandler,
     225                              self.dgram_examine)
     226  
     227      @requires_unix_sockets
     228      def test_UnixDatagramServer(self):
     229          self.run_server(socketserver.UnixDatagramServer,
     230                          socketserver.DatagramRequestHandler,
     231                          self.dgram_examine)
     232  
     233      @requires_unix_sockets
     234      def test_ThreadingUnixDatagramServer(self):
     235          self.run_server(socketserver.ThreadingUnixDatagramServer,
     236                          socketserver.DatagramRequestHandler,
     237                          self.dgram_examine)
     238  
     239      @requires_unix_sockets
     240      @requires_forking
     241      def test_ForkingUnixDatagramServer(self):
     242          self.run_server(socketserver.ForkingUnixDatagramServer,
     243                          socketserver.DatagramRequestHandler,
     244                          self.dgram_examine)
     245  
     246      @threading_helper.reap_threads
     247      def test_shutdown(self):
     248          # Issue #2302: shutdown() should always succeed in making an
     249          # other thread leave serve_forever().
     250          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     251              pass
     252  
     253          class ESC[4;38;5;81mMyHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     254              pass
     255  
     256          threads = []
     257          for i in range(20):
     258              s = MyServer((HOST, 0), MyHandler)
     259              t = threading.Thread(
     260                  name='MyServer serving',
     261                  target=s.serve_forever,
     262                  kwargs={'poll_interval':0.01})
     263              t.daemon = True  # In case this function raises.
     264              threads.append((t, s))
     265          for t, s in threads:
     266              t.start()
     267              s.shutdown()
     268          for t, s in threads:
     269              t.join()
     270              s.server_close()
     271  
     272      def test_close_immediately(self):
     273          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn, ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     274              pass
     275  
     276          server = MyServer((HOST, 0), lambda: None)
     277          server.server_close()
     278  
     279      def test_tcpserver_bind_leak(self):
     280          # Issue #22435: the server socket wouldn't be closed if bind()/listen()
     281          # failed.
     282          # Create many servers for which bind() will fail, to see if this result
     283          # in FD exhaustion.
     284          for i in range(1024):
     285              with self.assertRaises(OverflowError):
     286                  socketserver.TCPServer((HOST, -1),
     287                                         socketserver.StreamRequestHandler)
     288  
     289      def test_context_manager(self):
     290          with socketserver.TCPServer((HOST, 0),
     291                                      socketserver.StreamRequestHandler) as server:
     292              pass
     293          self.assertEqual(-1, server.socket.fileno())
     294  
     295  
     296  class ESC[4;38;5;81mErrorHandlerTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     297      """Test that the servers pass normal exceptions from the handler to
     298      handle_error(), and that exiting exceptions like SystemExit and
     299      KeyboardInterrupt are not passed."""
     300  
     301      def tearDown(self):
     302          os_helper.unlink(os_helper.TESTFN)
     303  
     304      def test_sync_handled(self):
     305          BaseErrorTestServer(ValueError)
     306          self.check_result(handled=True)
     307  
     308      def test_sync_not_handled(self):
     309          with self.assertRaises(SystemExit):
     310              BaseErrorTestServer(SystemExit)
     311          self.check_result(handled=False)
     312  
     313      def test_threading_handled(self):
     314          ThreadingErrorTestServer(ValueError)
     315          self.check_result(handled=True)
     316  
     317      def test_threading_not_handled(self):
     318          with threading_helper.catch_threading_exception() as cm:
     319              ThreadingErrorTestServer(SystemExit)
     320              self.check_result(handled=False)
     321  
     322              self.assertIs(cm.exc_type, SystemExit)
     323  
     324      @requires_forking
     325      def test_forking_handled(self):
     326          ForkingErrorTestServer(ValueError)
     327          self.check_result(handled=True)
     328  
     329      @requires_forking
     330      def test_forking_not_handled(self):
     331          ForkingErrorTestServer(SystemExit)
     332          self.check_result(handled=False)
     333  
     334      def check_result(self, handled):
     335          with open(os_helper.TESTFN) as log:
     336              expected = 'Handler called\n' + 'Error handled\n' * handled
     337              self.assertEqual(log.read(), expected)
     338  
     339  
     340  class ESC[4;38;5;81mBaseErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     341      def __init__(self, exception):
     342          self.exception = exception
     343          super().__init__((HOST, 0), BadHandler)
     344          with socket.create_connection(self.server_address):
     345              pass
     346          try:
     347              self.handle_request()
     348          finally:
     349              self.server_close()
     350          self.wait_done()
     351  
     352      def handle_error(self, request, client_address):
     353          with open(os_helper.TESTFN, 'a') as log:
     354              log.write('Error handled\n')
     355  
     356      def wait_done(self):
     357          pass
     358  
     359  
     360  class ESC[4;38;5;81mBadHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mBaseRequestHandler):
     361      def handle(self):
     362          with open(os_helper.TESTFN, 'a') as log:
     363              log.write('Handler called\n')
     364          raise self.server.exception('Test error')
     365  
     366  
     367  class ESC[4;38;5;81mThreadingErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn,
     368          ESC[4;38;5;149mBaseErrorTestServer):
     369      def __init__(self, *pos, **kw):
     370          self.done = threading.Event()
     371          super().__init__(*pos, **kw)
     372  
     373      def shutdown_request(self, *pos, **kw):
     374          super().shutdown_request(*pos, **kw)
     375          self.done.set()
     376  
     377      def wait_done(self):
     378          self.done.wait()
     379  
     380  
     381  if HAVE_FORKING:
     382      class ESC[4;38;5;81mForkingErrorTestServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mForkingMixIn, ESC[4;38;5;149mBaseErrorTestServer):
     383          pass
     384  
     385  
     386  class ESC[4;38;5;81mSocketWriterTest(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     387      def test_basics(self):
     388          class ESC[4;38;5;81mHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     389              def handle(self):
     390                  self.server.wfile = self.wfile
     391                  self.server.wfile_fileno = self.wfile.fileno()
     392                  self.server.request_fileno = self.request.fileno()
     393  
     394          server = socketserver.TCPServer((HOST, 0), Handler)
     395          self.addCleanup(server.server_close)
     396          s = socket.socket(
     397              server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
     398          with s:
     399              s.connect(server.server_address)
     400          server.handle_request()
     401          self.assertIsInstance(server.wfile, io.BufferedIOBase)
     402          self.assertEqual(server.wfile_fileno, server.request_fileno)
     403  
     404      def test_write(self):
     405          # Test that wfile.write() sends data immediately, and that it does
     406          # not truncate sends when interrupted by a Unix signal
     407          pthread_kill = test.support.get_attribute(signal, 'pthread_kill')
     408  
     409          class ESC[4;38;5;81mHandler(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mStreamRequestHandler):
     410              def handle(self):
     411                  self.server.sent1 = self.wfile.write(b'write data\n')
     412                  # Should be sent immediately, without requiring flush()
     413                  self.server.received = self.rfile.readline()
     414                  big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
     415                  self.server.sent2 = self.wfile.write(big_chunk)
     416  
     417          server = socketserver.TCPServer((HOST, 0), Handler)
     418          self.addCleanup(server.server_close)
     419          interrupted = threading.Event()
     420  
     421          def signal_handler(signum, frame):
     422              interrupted.set()
     423  
     424          original = signal.signal(signal.SIGUSR1, signal_handler)
     425          self.addCleanup(signal.signal, signal.SIGUSR1, original)
     426          response1 = None
     427          received2 = None
     428          main_thread = threading.get_ident()
     429  
     430          def run_client():
     431              s = socket.socket(server.address_family, socket.SOCK_STREAM,
     432                  socket.IPPROTO_TCP)
     433              with s, s.makefile('rb') as reader:
     434                  s.connect(server.server_address)
     435                  nonlocal response1
     436                  response1 = reader.readline()
     437                  s.sendall(b'client response\n')
     438  
     439                  reader.read(100)
     440                  # The main thread should now be blocking in a send() syscall.
     441                  # But in theory, it could get interrupted by other signals,
     442                  # and then retried. So keep sending the signal in a loop, in
     443                  # case an earlier signal happens to be delivered at an
     444                  # inconvenient moment.
     445                  while True:
     446                      pthread_kill(main_thread, signal.SIGUSR1)
     447                      if interrupted.wait(timeout=float(1)):
     448                          break
     449                  nonlocal received2
     450                  received2 = len(reader.read())
     451  
     452          background = threading.Thread(target=run_client)
     453          background.start()
     454          server.handle_request()
     455          background.join()
     456          self.assertEqual(server.sent1, len(response1))
     457          self.assertEqual(response1, b'write data\n')
     458          self.assertEqual(server.received, b'client response\n')
     459          self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
     460          self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
     461  
     462  
     463  class ESC[4;38;5;81mMiscTestCase(ESC[4;38;5;149munittestESC[4;38;5;149m.ESC[4;38;5;149mTestCase):
     464  
     465      def test_all(self):
     466          # objects defined in the module should be in __all__
     467          expected = []
     468          for name in dir(socketserver):
     469              if not name.startswith('_'):
     470                  mod_object = getattr(socketserver, name)
     471                  if getattr(mod_object, '__module__', None) == 'socketserver':
     472                      expected.append(name)
     473          self.assertCountEqual(socketserver.__all__, expected)
     474  
     475      def test_shutdown_request_called_if_verify_request_false(self):
     476          # Issue #26309: BaseServer should call shutdown_request even if
     477          # verify_request is False
     478  
     479          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     480              def verify_request(self, request, client_address):
     481                  return False
     482  
     483              shutdown_called = 0
     484              def shutdown_request(self, request):
     485                  self.shutdown_called += 1
     486                  socketserver.TCPServer.shutdown_request(self, request)
     487  
     488          server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
     489          s = socket.socket(server.address_family, socket.SOCK_STREAM)
     490          s.connect(server.server_address)
     491          s.close()
     492          server.handle_request()
     493          self.assertEqual(server.shutdown_called, 1)
     494          server.server_close()
     495  
     496      def test_threads_reaped(self):
     497          """
     498          In #37193, users reported a memory leak
     499          due to the saving of every request thread. Ensure that
     500          not all threads are kept forever.
     501          """
     502          class ESC[4;38;5;81mMyServer(ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mThreadingMixIn, ESC[4;38;5;149msocketserverESC[4;38;5;149m.ESC[4;38;5;149mTCPServer):
     503              pass
     504  
     505          server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
     506          for n in range(10):
     507              with socket.create_connection(server.server_address):
     508                  server.handle_request()
     509          self.assertLess(len(server._threads), 10)
     510          server.server_close()
     511  
     512  
     513  if __name__ == "__main__":
     514      unittest.main()