(root)/
Python-3.12.0/
Lib/
test/
test_wsgiref.py
       1  from unittest import mock
       2  from test import support
       3  from test.support import socket_helper
       4  from test.test_httpservers import NoLogRequestHandler
       5  from unittest import TestCase
       6  from wsgiref.util import setup_testing_defaults
       7  from wsgiref.headers import Headers
       8  from wsgiref.handlers import BaseHandler, BaseCGIHandler, SimpleHandler
       9  from wsgiref import util
      10  from wsgiref.validate import validator
      11  from wsgiref.simple_server import WSGIServer, WSGIRequestHandler
      12  from wsgiref.simple_server import make_server
      13  from http.client import HTTPConnection
      14  from io import StringIO, BytesIO, BufferedReader
      15  from socketserver import BaseServer
      16  from platform import python_implementation
      17  
      18  import os
      19  import re
      20  import signal
      21  import sys
      22  import threading
      23  import unittest
      24  
      25  
      26  class ESC[4;38;5;81mMockServer(ESC[4;38;5;149mWSGIServer):
      27      """Non-socket HTTP server"""
      28  
      29      def __init__(self, server_address, RequestHandlerClass):
      30          BaseServer.__init__(self, server_address, RequestHandlerClass)
      31          self.server_bind()
      32  
      33      def server_bind(self):
      34          host, port = self.server_address
      35          self.server_name = host
      36          self.server_port = port
      37          self.setup_environ()
      38  
      39  
      40  class ESC[4;38;5;81mMockHandler(ESC[4;38;5;149mWSGIRequestHandler):
      41      """Non-socket HTTP handler"""
      42      def setup(self):
      43          self.connection = self.request
      44          self.rfile, self.wfile = self.connection
      45  
      46      def finish(self):
      47          pass
      48  
      49  
      50  def hello_app(environ,start_response):
      51      start_response("200 OK", [
      52          ('Content-Type','text/plain'),
      53          ('Date','Mon, 05 Jun 2006 18:49:54 GMT')
      54      ])
      55      return [b"Hello, world!"]
      56  
      57  
      58  def header_app(environ, start_response):
      59      start_response("200 OK", [
      60          ('Content-Type', 'text/plain'),
      61          ('Date', 'Mon, 05 Jun 2006 18:49:54 GMT')
      62      ])
      63      return [';'.join([
      64          environ['HTTP_X_TEST_HEADER'], environ['QUERY_STRING'],
      65          environ['PATH_INFO']
      66      ]).encode('iso-8859-1')]
      67  
      68  
      69  def run_amock(app=hello_app, data=b"GET / HTTP/1.0\n\n"):
      70      server = make_server("", 80, app, MockServer, MockHandler)
      71      inp = BufferedReader(BytesIO(data))
      72      out = BytesIO()
      73      olderr = sys.stderr
      74      err = sys.stderr = StringIO()
      75  
      76      try:
      77          server.finish_request((inp, out), ("127.0.0.1",8888))
      78      finally:
      79          sys.stderr = olderr
      80  
      81      return out.getvalue(), err.getvalue()
      82  
      83  
      84  def compare_generic_iter(make_it, match):
      85      """Utility to compare a generic iterator with an iterable
      86  
      87      This tests the iterator using iter()/next().
      88      'make_it' must be a function returning a fresh
      89      iterator to be tested (since this may test the iterator twice)."""
      90  
      91      it = make_it()
      92      if not iter(it) is it:
      93          raise AssertionError
      94      for item in match:
      95          if not next(it) == item:
      96              raise AssertionError
      97      try:
      98          next(it)
      99      except StopIteration:
     100          pass
     101      else:
     102          raise AssertionError("Too many items from .__next__()", it)
     103  
     104  
     105  class ESC[4;38;5;81mIntegrationTests(ESC[4;38;5;149mTestCase):
     106  
     107      def check_hello(self, out, has_length=True):
     108          pyver = (python_implementation() + "/" +
     109                  sys.version.split()[0])
     110          self.assertEqual(out,
     111              ("HTTP/1.0 200 OK\r\n"
     112              "Server: WSGIServer/0.2 " + pyver +"\r\n"
     113              "Content-Type: text/plain\r\n"
     114              "Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
     115              (has_length and  "Content-Length: 13\r\n" or "") +
     116              "\r\n"
     117              "Hello, world!").encode("iso-8859-1")
     118          )
     119  
     120      def test_plain_hello(self):
     121          out, err = run_amock()
     122          self.check_hello(out)
     123  
     124      def test_environ(self):
     125          request = (
     126              b"GET /p%61th/?query=test HTTP/1.0\n"
     127              b"X-Test-Header: Python test \n"
     128              b"X-Test-Header: Python test 2\n"
     129              b"Content-Length: 0\n\n"
     130          )
     131          out, err = run_amock(header_app, request)
     132          self.assertEqual(
     133              out.splitlines()[-1],
     134              b"Python test,Python test 2;query=test;/path/"
     135          )
     136  
     137      def test_request_length(self):
     138          out, err = run_amock(data=b"GET " + (b"x" * 65537) + b" HTTP/1.0\n\n")
     139          self.assertEqual(out.splitlines()[0],
     140                           b"HTTP/1.0 414 Request-URI Too Long")
     141  
     142      def test_validated_hello(self):
     143          out, err = run_amock(validator(hello_app))
     144          # the middleware doesn't support len(), so content-length isn't there
     145          self.check_hello(out, has_length=False)
     146  
     147      def test_simple_validation_error(self):
     148          def bad_app(environ,start_response):
     149              start_response("200 OK", ('Content-Type','text/plain'))
     150              return ["Hello, world!"]
     151          out, err = run_amock(validator(bad_app))
     152          self.assertTrue(out.endswith(
     153              b"A server error occurred.  Please contact the administrator."
     154          ))
     155          self.assertEqual(
     156              err.splitlines()[-2],
     157              "AssertionError: Headers (('Content-Type', 'text/plain')) must"
     158              " be of type list: <class 'tuple'>"
     159          )
     160  
     161      def test_status_validation_errors(self):
     162          def create_bad_app(status):
     163              def bad_app(environ, start_response):
     164                  start_response(status, [("Content-Type", "text/plain; charset=utf-8")])
     165                  return [b"Hello, world!"]
     166              return bad_app
     167  
     168          tests = [
     169              ('200', 'AssertionError: Status must be at least 4 characters'),
     170              ('20X OK', 'AssertionError: Status message must begin w/3-digit code'),
     171              ('200OK', 'AssertionError: Status message must have a space after code'),
     172          ]
     173  
     174          for status, exc_message in tests:
     175              with self.subTest(status=status):
     176                  out, err = run_amock(create_bad_app(status))
     177                  self.assertTrue(out.endswith(
     178                      b"A server error occurred.  Please contact the administrator."
     179                  ))
     180                  self.assertEqual(err.splitlines()[-2], exc_message)
     181  
     182      def test_wsgi_input(self):
     183          def bad_app(e,s):
     184              e["wsgi.input"].read()
     185              s("200 OK", [("Content-Type", "text/plain; charset=utf-8")])
     186              return [b"data"]
     187          out, err = run_amock(validator(bad_app))
     188          self.assertTrue(out.endswith(
     189              b"A server error occurred.  Please contact the administrator."
     190          ))
     191          self.assertEqual(
     192              err.splitlines()[-2], "AssertionError"
     193          )
     194  
     195      def test_bytes_validation(self):
     196          def app(e, s):
     197              s("200 OK", [
     198                  ("Content-Type", "text/plain; charset=utf-8"),
     199                  ("Date", "Wed, 24 Dec 2008 13:29:32 GMT"),
     200                  ])
     201              return [b"data"]
     202          out, err = run_amock(validator(app))
     203          self.assertTrue(err.endswith('"GET / HTTP/1.0" 200 4\n'))
     204          ver = sys.version.split()[0].encode('ascii')
     205          py  = python_implementation().encode('ascii')
     206          pyver = py + b"/" + ver
     207          self.assertEqual(
     208                  b"HTTP/1.0 200 OK\r\n"
     209                  b"Server: WSGIServer/0.2 "+ pyver + b"\r\n"
     210                  b"Content-Type: text/plain; charset=utf-8\r\n"
     211                  b"Date: Wed, 24 Dec 2008 13:29:32 GMT\r\n"
     212                  b"\r\n"
     213                  b"data",
     214                  out)
     215  
     216      def test_cp1252_url(self):
     217          def app(e, s):
     218              s("200 OK", [
     219                  ("Content-Type", "text/plain"),
     220                  ("Date", "Wed, 24 Dec 2008 13:29:32 GMT"),
     221                  ])
     222              # PEP3333 says environ variables are decoded as latin1.
     223              # Encode as latin1 to get original bytes
     224              return [e["PATH_INFO"].encode("latin1")]
     225  
     226          out, err = run_amock(
     227              validator(app), data=b"GET /\x80%80 HTTP/1.0")
     228          self.assertEqual(
     229              [
     230                  b"HTTP/1.0 200 OK",
     231                  mock.ANY,
     232                  b"Content-Type: text/plain",
     233                  b"Date: Wed, 24 Dec 2008 13:29:32 GMT",
     234                  b"",
     235                  b"/\x80\x80",
     236              ],
     237              out.splitlines())
     238  
     239      def test_interrupted_write(self):
     240          # BaseHandler._write() and _flush() have to write all data, even if
     241          # it takes multiple send() calls.  Test this by interrupting a send()
     242          # call with a Unix signal.
     243          pthread_kill = support.get_attribute(signal, "pthread_kill")
     244  
     245          def app(environ, start_response):
     246              start_response("200 OK", [])
     247              return [b'\0' * support.SOCK_MAX_SIZE]
     248  
     249          class ESC[4;38;5;81mWsgiHandler(ESC[4;38;5;149mNoLogRequestHandler, ESC[4;38;5;149mWSGIRequestHandler):
     250              pass
     251  
     252          server = make_server(socket_helper.HOST, 0, app, handler_class=WsgiHandler)
     253          self.addCleanup(server.server_close)
     254          interrupted = threading.Event()
     255  
     256          def signal_handler(signum, frame):
     257              interrupted.set()
     258  
     259          original = signal.signal(signal.SIGUSR1, signal_handler)
     260          self.addCleanup(signal.signal, signal.SIGUSR1, original)
     261          received = None
     262          main_thread = threading.get_ident()
     263  
     264          def run_client():
     265              http = HTTPConnection(*server.server_address)
     266              http.request("GET", "/")
     267              with http.getresponse() as response:
     268                  response.read(100)
     269                  # The main thread should now be blocking in a send() system
     270                  # call.  But in theory, it could get interrupted by other
     271                  # signals, and then retried.  So keep sending the signal in a
     272                  # loop, in case an earlier signal happens to be delivered at
     273                  # an inconvenient moment.
     274                  while True:
     275                      pthread_kill(main_thread, signal.SIGUSR1)
     276                      if interrupted.wait(timeout=float(1)):
     277                          break
     278                  nonlocal received
     279                  received = len(response.read())
     280              http.close()
     281  
     282          background = threading.Thread(target=run_client)
     283          background.start()
     284          server.handle_request()
     285          background.join()
     286          self.assertEqual(received, support.SOCK_MAX_SIZE - 100)
     287  
     288  
     289  class ESC[4;38;5;81mUtilityTests(ESC[4;38;5;149mTestCase):
     290  
     291      def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
     292          env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in}
     293          util.setup_testing_defaults(env)
     294          self.assertEqual(util.shift_path_info(env),part)
     295          self.assertEqual(env['PATH_INFO'],pi_out)
     296          self.assertEqual(env['SCRIPT_NAME'],sn_out)
     297          return env
     298  
     299      def checkDefault(self, key, value, alt=None):
     300          # Check defaulting when empty
     301          env = {}
     302          util.setup_testing_defaults(env)
     303          if isinstance(value, StringIO):
     304              self.assertIsInstance(env[key], StringIO)
     305          elif isinstance(value,BytesIO):
     306              self.assertIsInstance(env[key],BytesIO)
     307          else:
     308              self.assertEqual(env[key], value)
     309  
     310          # Check existing value
     311          env = {key:alt}
     312          util.setup_testing_defaults(env)
     313          self.assertIs(env[key], alt)
     314  
     315      def checkCrossDefault(self,key,value,**kw):
     316          util.setup_testing_defaults(kw)
     317          self.assertEqual(kw[key],value)
     318  
     319      def checkAppURI(self,uri,**kw):
     320          util.setup_testing_defaults(kw)
     321          self.assertEqual(util.application_uri(kw),uri)
     322  
     323      def checkReqURI(self,uri,query=1,**kw):
     324          util.setup_testing_defaults(kw)
     325          self.assertEqual(util.request_uri(kw,query),uri)
     326  
     327      def checkFW(self,text,size,match):
     328  
     329          def make_it(text=text,size=size):
     330              return util.FileWrapper(StringIO(text),size)
     331  
     332          compare_generic_iter(make_it,match)
     333  
     334          it = make_it()
     335          self.assertFalse(it.filelike.closed)
     336  
     337          for item in it:
     338              pass
     339  
     340          self.assertFalse(it.filelike.closed)
     341  
     342          it.close()
     343          self.assertTrue(it.filelike.closed)
     344  
     345      def testSimpleShifts(self):
     346          self.checkShift('','/', '', '/', '')
     347          self.checkShift('','/x', 'x', '/x', '')
     348          self.checkShift('/','', None, '/', '')
     349          self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
     350          self.checkShift('/a','/x/',  'x', '/a/x', '/')
     351  
     352      def testNormalizedShifts(self):
     353          self.checkShift('/a/b', '/../y', '..', '/a', '/y')
     354          self.checkShift('', '/../y', '..', '', '/y')
     355          self.checkShift('/a/b', '//y', 'y', '/a/b/y', '')
     356          self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/')
     357          self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '')
     358          self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/')
     359          self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/')
     360          self.checkShift('/a/b', '///', '', '/a/b/', '')
     361          self.checkShift('/a/b', '/.//', '', '/a/b/', '')
     362          self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
     363          self.checkShift('/a/b', '/.', None, '/a/b', '')
     364  
     365      def testDefaults(self):
     366          for key, value in [
     367              ('SERVER_NAME','127.0.0.1'),
     368              ('SERVER_PORT', '80'),
     369              ('SERVER_PROTOCOL','HTTP/1.0'),
     370              ('HTTP_HOST','127.0.0.1'),
     371              ('REQUEST_METHOD','GET'),
     372              ('SCRIPT_NAME',''),
     373              ('PATH_INFO','/'),
     374              ('wsgi.version', (1,0)),
     375              ('wsgi.run_once', 0),
     376              ('wsgi.multithread', 0),
     377              ('wsgi.multiprocess', 0),
     378              ('wsgi.input', BytesIO()),
     379              ('wsgi.errors', StringIO()),
     380              ('wsgi.url_scheme','http'),
     381          ]:
     382              self.checkDefault(key,value)
     383  
     384      def testCrossDefaults(self):
     385          self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
     386          self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
     387          self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1")
     388          self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes")
     389          self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo")
     390          self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
     391          self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")
     392  
     393      def testGuessScheme(self):
     394          self.assertEqual(util.guess_scheme({}), "http")
     395          self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
     396          self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https")
     397          self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
     398          self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")
     399  
     400      def testAppURIs(self):
     401          self.checkAppURI("http://127.0.0.1/")
     402          self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
     403          self.checkAppURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
     404          self.checkAppURI("http://spam.example.com:2071/",
     405              HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071")
     406          self.checkAppURI("http://spam.example.com/",
     407              SERVER_NAME="spam.example.com")
     408          self.checkAppURI("http://127.0.0.1/",
     409              HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com")
     410          self.checkAppURI("https://127.0.0.1/", HTTPS="on")
     411          self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000",
     412              HTTP_HOST=None)
     413  
     414      def testReqURIs(self):
     415          self.checkReqURI("http://127.0.0.1/")
     416          self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
     417          self.checkReqURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
     418          self.checkReqURI("http://127.0.0.1/spammity/spam",
     419              SCRIPT_NAME="/spammity", PATH_INFO="/spam")
     420          self.checkReqURI("http://127.0.0.1/spammity/sp%E4m",
     421              SCRIPT_NAME="/spammity", PATH_INFO="/sp\xe4m")
     422          self.checkReqURI("http://127.0.0.1/spammity/spam;ham",
     423              SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham")
     424          self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678",
     425              SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678")
     426          self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni",
     427              SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
     428          self.checkReqURI("http://127.0.0.1/spammity/spam?s%E4y=ni",
     429              SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="s%E4y=ni")
     430          self.checkReqURI("http://127.0.0.1/spammity/spam", 0,
     431              SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
     432  
     433      def testFileWrapper(self):
     434          self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])
     435  
     436      def testHopByHop(self):
     437          for hop in (
     438              "Connection Keep-Alive Proxy-Authenticate Proxy-Authorization "
     439              "TE Trailers Transfer-Encoding Upgrade"
     440          ).split():
     441              for alt in hop, hop.title(), hop.upper(), hop.lower():
     442                  self.assertTrue(util.is_hop_by_hop(alt))
     443  
     444          # Not comprehensive, just a few random header names
     445          for hop in (
     446              "Accept Cache-Control Date Pragma Trailer Via Warning"
     447          ).split():
     448              for alt in hop, hop.title(), hop.upper(), hop.lower():
     449                  self.assertFalse(util.is_hop_by_hop(alt))
     450  
     451  class ESC[4;38;5;81mHeaderTests(ESC[4;38;5;149mTestCase):
     452  
     453      def testMappingInterface(self):
     454          test = [('x','y')]
     455          self.assertEqual(len(Headers()), 0)
     456          self.assertEqual(len(Headers([])),0)
     457          self.assertEqual(len(Headers(test[:])),1)
     458          self.assertEqual(Headers(test[:]).keys(), ['x'])
     459          self.assertEqual(Headers(test[:]).values(), ['y'])
     460          self.assertEqual(Headers(test[:]).items(), test)
     461          self.assertIsNot(Headers(test).items(), test)  # must be copy!
     462  
     463          h = Headers()
     464          del h['foo']   # should not raise an error
     465  
     466          h['Foo'] = 'bar'
     467          for m in h.__contains__, h.get, h.get_all, h.__getitem__:
     468              self.assertTrue(m('foo'))
     469              self.assertTrue(m('Foo'))
     470              self.assertTrue(m('FOO'))
     471              self.assertFalse(m('bar'))
     472  
     473          self.assertEqual(h['foo'],'bar')
     474          h['foo'] = 'baz'
     475          self.assertEqual(h['FOO'],'baz')
     476          self.assertEqual(h.get_all('foo'),['baz'])
     477  
     478          self.assertEqual(h.get("foo","whee"), "baz")
     479          self.assertEqual(h.get("zoo","whee"), "whee")
     480          self.assertEqual(h.setdefault("foo","whee"), "baz")
     481          self.assertEqual(h.setdefault("zoo","whee"), "whee")
     482          self.assertEqual(h["foo"],"baz")
     483          self.assertEqual(h["zoo"],"whee")
     484  
     485      def testRequireList(self):
     486          self.assertRaises(TypeError, Headers, "foo")
     487  
     488      def testExtras(self):
     489          h = Headers()
     490          self.assertEqual(str(h),'\r\n')
     491  
     492          h.add_header('foo','bar',baz="spam")
     493          self.assertEqual(h['foo'], 'bar; baz="spam"')
     494          self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n')
     495  
     496          h.add_header('Foo','bar',cheese=None)
     497          self.assertEqual(h.get_all('foo'),
     498              ['bar; baz="spam"', 'bar; cheese'])
     499  
     500          self.assertEqual(str(h),
     501              'foo: bar; baz="spam"\r\n'
     502              'Foo: bar; cheese\r\n'
     503              '\r\n'
     504          )
     505  
     506  class ESC[4;38;5;81mErrorHandler(ESC[4;38;5;149mBaseCGIHandler):
     507      """Simple handler subclass for testing BaseHandler"""
     508  
     509      # BaseHandler records the OS environment at import time, but envvars
     510      # might have been changed later by other tests, which trips up
     511      # HandlerTests.testEnviron().
     512      os_environ = dict(os.environ.items())
     513  
     514      def __init__(self,**kw):
     515          setup_testing_defaults(kw)
     516          BaseCGIHandler.__init__(
     517              self, BytesIO(), BytesIO(), StringIO(), kw,
     518              multithread=True, multiprocess=True
     519          )
     520  
     521  class ESC[4;38;5;81mTestHandler(ESC[4;38;5;149mErrorHandler):
     522      """Simple handler subclass for testing BaseHandler, w/error passthru"""
     523  
     524      def handle_error(self):
     525          raise   # for testing, we want to see what's happening
     526  
     527  
     528  class ESC[4;38;5;81mHandlerTests(ESC[4;38;5;149mTestCase):
     529      # testEnviron() can produce long error message
     530      maxDiff = 80 * 50
     531  
     532      def testEnviron(self):
     533          os_environ = {
     534              # very basic environment
     535              'HOME': '/my/home',
     536              'PATH': '/my/path',
     537              'LANG': 'fr_FR.UTF-8',
     538  
     539              # set some WSGI variables
     540              'SCRIPT_NAME': 'test_script_name',
     541              'SERVER_NAME': 'test_server_name',
     542          }
     543  
     544          with support.swap_attr(TestHandler, 'os_environ', os_environ):
     545              # override X and HOME variables
     546              handler = TestHandler(X="Y", HOME="/override/home")
     547              handler.setup_environ()
     548  
     549          # Check that wsgi_xxx attributes are copied to wsgi.xxx variables
     550          # of handler.environ
     551          for attr in ('version', 'multithread', 'multiprocess', 'run_once',
     552                       'file_wrapper'):
     553              self.assertEqual(getattr(handler, 'wsgi_' + attr),
     554                               handler.environ['wsgi.' + attr])
     555  
     556          # Test handler.environ as a dict
     557          expected = {}
     558          setup_testing_defaults(expected)
     559          # Handler inherits os_environ variables which are not overridden
     560          # by SimpleHandler.add_cgi_vars() (SimpleHandler.base_env)
     561          for key, value in os_environ.items():
     562              if key not in expected:
     563                  expected[key] = value
     564          expected.update({
     565              # X doesn't exist in os_environ
     566              "X": "Y",
     567              # HOME is overridden by TestHandler
     568              'HOME': "/override/home",
     569  
     570              # overridden by setup_testing_defaults()
     571              "SCRIPT_NAME": "",
     572              "SERVER_NAME": "127.0.0.1",
     573  
     574              # set by BaseHandler.setup_environ()
     575              'wsgi.input': handler.get_stdin(),
     576              'wsgi.errors': handler.get_stderr(),
     577              'wsgi.version': (1, 0),
     578              'wsgi.run_once': False,
     579              'wsgi.url_scheme': 'http',
     580              'wsgi.multithread': True,
     581              'wsgi.multiprocess': True,
     582              'wsgi.file_wrapper': util.FileWrapper,
     583          })
     584          self.assertDictEqual(handler.environ, expected)
     585  
     586      def testCGIEnviron(self):
     587          h = BaseCGIHandler(None,None,None,{})
     588          h.setup_environ()
     589          for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors':
     590              self.assertIn(key, h.environ)
     591  
     592      def testScheme(self):
     593          h=TestHandler(HTTPS="on"); h.setup_environ()
     594          self.assertEqual(h.environ['wsgi.url_scheme'],'https')
     595          h=TestHandler(); h.setup_environ()
     596          self.assertEqual(h.environ['wsgi.url_scheme'],'http')
     597  
     598      def testAbstractMethods(self):
     599          h = BaseHandler()
     600          for name in [
     601              '_flush','get_stdin','get_stderr','add_cgi_vars'
     602          ]:
     603              self.assertRaises(NotImplementedError, getattr(h,name))
     604          self.assertRaises(NotImplementedError, h._write, "test")
     605  
     606      def testContentLength(self):
     607          # Demo one reason iteration is better than write()...  ;)
     608  
     609          def trivial_app1(e,s):
     610              s('200 OK',[])
     611              return [e['wsgi.url_scheme'].encode('iso-8859-1')]
     612  
     613          def trivial_app2(e,s):
     614              s('200 OK',[])(e['wsgi.url_scheme'].encode('iso-8859-1'))
     615              return []
     616  
     617          def trivial_app3(e,s):
     618              s('200 OK',[])
     619              return ['\u0442\u0435\u0441\u0442'.encode("utf-8")]
     620  
     621          def trivial_app4(e,s):
     622              # Simulate a response to a HEAD request
     623              s('200 OK',[('Content-Length', '12345')])
     624              return []
     625  
     626          h = TestHandler()
     627          h.run(trivial_app1)
     628          self.assertEqual(h.stdout.getvalue(),
     629              ("Status: 200 OK\r\n"
     630              "Content-Length: 4\r\n"
     631              "\r\n"
     632              "http").encode("iso-8859-1"))
     633  
     634          h = TestHandler()
     635          h.run(trivial_app2)
     636          self.assertEqual(h.stdout.getvalue(),
     637              ("Status: 200 OK\r\n"
     638              "\r\n"
     639              "http").encode("iso-8859-1"))
     640  
     641          h = TestHandler()
     642          h.run(trivial_app3)
     643          self.assertEqual(h.stdout.getvalue(),
     644              b'Status: 200 OK\r\n'
     645              b'Content-Length: 8\r\n'
     646              b'\r\n'
     647              b'\xd1\x82\xd0\xb5\xd1\x81\xd1\x82')
     648  
     649          h = TestHandler()
     650          h.run(trivial_app4)
     651          self.assertEqual(h.stdout.getvalue(),
     652              b'Status: 200 OK\r\n'
     653              b'Content-Length: 12345\r\n'
     654              b'\r\n')
     655  
     656      def testBasicErrorOutput(self):
     657  
     658          def non_error_app(e,s):
     659              s('200 OK',[])
     660              return []
     661  
     662          def error_app(e,s):
     663              raise AssertionError("This should be caught by handler")
     664  
     665          h = ErrorHandler()
     666          h.run(non_error_app)
     667          self.assertEqual(h.stdout.getvalue(),
     668              ("Status: 200 OK\r\n"
     669              "Content-Length: 0\r\n"
     670              "\r\n").encode("iso-8859-1"))
     671          self.assertEqual(h.stderr.getvalue(),"")
     672  
     673          h = ErrorHandler()
     674          h.run(error_app)
     675          self.assertEqual(h.stdout.getvalue(),
     676              ("Status: %s\r\n"
     677              "Content-Type: text/plain\r\n"
     678              "Content-Length: %d\r\n"
     679              "\r\n" % (h.error_status,len(h.error_body))).encode('iso-8859-1')
     680              + h.error_body)
     681  
     682          self.assertIn("AssertionError", h.stderr.getvalue())
     683  
     684      def testErrorAfterOutput(self):
     685          MSG = b"Some output has been sent"
     686          def error_app(e,s):
     687              s("200 OK",[])(MSG)
     688              raise AssertionError("This should be caught by handler")
     689  
     690          h = ErrorHandler()
     691          h.run(error_app)
     692          self.assertEqual(h.stdout.getvalue(),
     693              ("Status: 200 OK\r\n"
     694              "\r\n".encode("iso-8859-1")+MSG))
     695          self.assertIn("AssertionError", h.stderr.getvalue())
     696  
     697      def testHeaderFormats(self):
     698  
     699          def non_error_app(e,s):
     700              s('200 OK',[])
     701              return []
     702  
     703          stdpat = (
     704              r"HTTP/%s 200 OK\r\n"
     705              r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n"
     706              r"%s" r"Content-Length: 0\r\n" r"\r\n"
     707          )
     708          shortpat = (
     709              "Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n"
     710          ).encode("iso-8859-1")
     711  
     712          for ssw in "FooBar/1.0", None:
     713              sw = ssw and "Server: %s\r\n" % ssw or ""
     714  
     715              for version in "1.0", "1.1":
     716                  for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1":
     717  
     718                      h = TestHandler(SERVER_PROTOCOL=proto)
     719                      h.origin_server = False
     720                      h.http_version = version
     721                      h.server_software = ssw
     722                      h.run(non_error_app)
     723                      self.assertEqual(shortpat,h.stdout.getvalue())
     724  
     725                      h = TestHandler(SERVER_PROTOCOL=proto)
     726                      h.origin_server = True
     727                      h.http_version = version
     728                      h.server_software = ssw
     729                      h.run(non_error_app)
     730                      if proto=="HTTP/0.9":
     731                          self.assertEqual(h.stdout.getvalue(),b"")
     732                      else:
     733                          self.assertTrue(
     734                              re.match((stdpat%(version,sw)).encode("iso-8859-1"),
     735                                  h.stdout.getvalue()),
     736                              ((stdpat%(version,sw)).encode("iso-8859-1"),
     737                                  h.stdout.getvalue())
     738                          )
     739  
     740      def testBytesData(self):
     741          def app(e, s):
     742              s("200 OK", [
     743                  ("Content-Type", "text/plain; charset=utf-8"),
     744                  ])
     745              return [b"data"]
     746  
     747          h = TestHandler()
     748          h.run(app)
     749          self.assertEqual(b"Status: 200 OK\r\n"
     750              b"Content-Type: text/plain; charset=utf-8\r\n"
     751              b"Content-Length: 4\r\n"
     752              b"\r\n"
     753              b"data",
     754              h.stdout.getvalue())
     755  
     756      def testCloseOnError(self):
     757          side_effects = {'close_called': False}
     758          MSG = b"Some output has been sent"
     759          def error_app(e,s):
     760              s("200 OK",[])(MSG)
     761              class ESC[4;38;5;81mCrashyIterable(ESC[4;38;5;149mobject):
     762                  def __iter__(self):
     763                      while True:
     764                          yield b'blah'
     765                          raise AssertionError("This should be caught by handler")
     766                  def close(self):
     767                      side_effects['close_called'] = True
     768              return CrashyIterable()
     769  
     770          h = ErrorHandler()
     771          h.run(error_app)
     772          self.assertEqual(side_effects['close_called'], True)
     773  
     774      def testPartialWrite(self):
     775          written = bytearray()
     776  
     777          class ESC[4;38;5;81mPartialWriter:
     778              def write(self, b):
     779                  partial = b[:7]
     780                  written.extend(partial)
     781                  return len(partial)
     782  
     783              def flush(self):
     784                  pass
     785  
     786          environ = {"SERVER_PROTOCOL": "HTTP/1.0"}
     787          h = SimpleHandler(BytesIO(), PartialWriter(), sys.stderr, environ)
     788          msg = "should not do partial writes"
     789          with self.assertWarnsRegex(DeprecationWarning, msg):
     790              h.run(hello_app)
     791          self.assertEqual(b"HTTP/1.0 200 OK\r\n"
     792              b"Content-Type: text/plain\r\n"
     793              b"Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n"
     794              b"Content-Length: 13\r\n"
     795              b"\r\n"
     796              b"Hello, world!",
     797              written)
     798  
     799      def testClientConnectionTerminations(self):
     800          environ = {"SERVER_PROTOCOL": "HTTP/1.0"}
     801          for exception in (
     802              ConnectionAbortedError,
     803              BrokenPipeError,
     804              ConnectionResetError,
     805          ):
     806              with self.subTest(exception=exception):
     807                  class ESC[4;38;5;81mAbortingWriter:
     808                      def write(self, b):
     809                          raise exception
     810  
     811                  stderr = StringIO()
     812                  h = SimpleHandler(BytesIO(), AbortingWriter(), stderr, environ)
     813                  h.run(hello_app)
     814  
     815                  self.assertFalse(stderr.getvalue())
     816  
     817      def testDontResetInternalStateOnException(self):
     818          class ESC[4;38;5;81mCustomException(ESC[4;38;5;149mValueError):
     819              pass
     820  
     821          # We are raising CustomException here to trigger an exception
     822          # during the execution of SimpleHandler.finish_response(), so
     823          # we can easily test that the internal state of the handler is
     824          # preserved in case of an exception.
     825          class ESC[4;38;5;81mAbortingWriter:
     826              def write(self, b):
     827                  raise CustomException
     828  
     829          stderr = StringIO()
     830          environ = {"SERVER_PROTOCOL": "HTTP/1.0"}
     831          h = SimpleHandler(BytesIO(), AbortingWriter(), stderr, environ)
     832          h.run(hello_app)
     833  
     834          self.assertIn("CustomException", stderr.getvalue())
     835  
     836          # Test that the internal state of the handler is preserved.
     837          self.assertIsNotNone(h.result)
     838          self.assertIsNotNone(h.headers)
     839          self.assertIsNotNone(h.status)
     840          self.assertIsNotNone(h.environ)
     841  
     842  
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()