python (3.11.7)

(root)/
lib/
python3.11/
test/
audit-tests.py
       1  """This script contains the actual auditing tests.
       2  
       3  It should not be imported directly, but should be run by the test_audit
       4  module with arguments identifying each test.
       5  
       6  """
       7  
       8  import contextlib
       9  import os
      10  import sys
      11  
      12  
      13  class ESC[4;38;5;81mTestHook:
      14      """Used in standard hook tests to collect any logged events.
      15  
      16      Should be used in a with block to ensure that it has no impact
      17      after the test completes.
      18      """
      19  
      20      def __init__(self, raise_on_events=None, exc_type=RuntimeError):
      21          self.raise_on_events = raise_on_events or ()
      22          self.exc_type = exc_type
      23          self.seen = []
      24          self.closed = False
      25  
      26      def __enter__(self, *a):
      27          sys.addaudithook(self)
      28          return self
      29  
      30      def __exit__(self, *a):
      31          self.close()
      32  
      33      def close(self):
      34          self.closed = True
      35  
      36      @property
      37      def seen_events(self):
      38          return [i[0] for i in self.seen]
      39  
      40      def __call__(self, event, args):
      41          if self.closed:
      42              return
      43          self.seen.append((event, args))
      44          if event in self.raise_on_events:
      45              raise self.exc_type("saw event " + event)
      46  
      47  
      48  # Simple helpers, since we are not in unittest here
      49  def assertEqual(x, y):
      50      if x != y:
      51          raise AssertionError(f"{x!r} should equal {y!r}")
      52  
      53  
      54  def assertIn(el, series):
      55      if el not in series:
      56          raise AssertionError(f"{el!r} should be in {series!r}")
      57  
      58  
      59  def assertNotIn(el, series):
      60      if el in series:
      61          raise AssertionError(f"{el!r} should not be in {series!r}")
      62  
      63  
      64  def assertSequenceEqual(x, y):
      65      if len(x) != len(y):
      66          raise AssertionError(f"{x!r} should equal {y!r}")
      67      if any(ix != iy for ix, iy in zip(x, y)):
      68          raise AssertionError(f"{x!r} should equal {y!r}")
      69  
      70  
      71  @contextlib.contextmanager
      72  def assertRaises(ex_type):
      73      try:
      74          yield
      75          assert False, f"expected {ex_type}"
      76      except BaseException as ex:
      77          if isinstance(ex, AssertionError):
      78              raise
      79          assert type(ex) is ex_type, f"{ex} should be {ex_type}"
      80  
      81  
      82  def test_basic():
      83      with TestHook() as hook:
      84          sys.audit("test_event", 1, 2, 3)
      85          assertEqual(hook.seen[0][0], "test_event")
      86          assertEqual(hook.seen[0][1], (1, 2, 3))
      87  
      88  
def test_block_add_hook():
    """Check that a hook raising on "sys.addaudithook" blocks later hooks.

    hook1 raises RuntimeError (the TestHook default) when it sees the
    "sys.addaudithook" event, so hook2 is expected never to record
    "test_event" while hook1 does.
    """
    # Raising an exception should prevent a new hook from being added,
    # but will not propagate out.
    with TestHook(raise_on_events="sys.addaudithook") as hook1:
        with TestHook() as hook2:
            sys.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)
      97  
      98  
def test_block_add_hook_baseexception():
    """Check that a BaseException raised by a hook escapes hook installation.

    hook1 raises BaseException when it sees "sys.addaudithook"; entering the
    second TestHook is expected to raise that exception, which the enclosing
    assertRaises then requires to be exactly BaseException.
    """
    # Raising BaseException will propagate out when adding a hook
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events="sys.addaudithook", exc_type=BaseException
        ) as hook1:
            # Adding this next hook should raise BaseException
            with TestHook() as hook2:
                pass
     108  
     109  
     110  def test_marshal():
     111      import marshal
     112      o = ("a", "b", "c", 1, 2, 3)
     113      payload = marshal.dumps(o)
     114  
     115      with TestHook() as hook:
     116          assertEqual(o, marshal.loads(marshal.dumps(o)))
     117  
     118          try:
     119              with open("test-marshal.bin", "wb") as f:
     120                  marshal.dump(o, f)
     121              with open("test-marshal.bin", "rb") as f:
     122                  assertEqual(o, marshal.load(f))
     123          finally:
     124              os.unlink("test-marshal.bin")
     125  
     126      actual = [(a[0], a[1]) for e, a in hook.seen if e == "marshal.dumps"]
     127      assertSequenceEqual(actual, [(o, marshal.version)] * 2)
     128  
     129      actual = [a[0] for e, a in hook.seen if e == "marshal.loads"]
     130      assertSequenceEqual(actual, [payload])
     131  
     132      actual = [e for e, a in hook.seen if e == "marshal.load"]
     133      assertSequenceEqual(actual, ["marshal.load"])
     134  
     135  
     136  def test_pickle():
     137      import pickle
     138  
     139      class ESC[4;38;5;81mPicklePrint:
     140          def __reduce_ex__(self, p):
     141              return str, ("Pwned!",)
     142  
     143      payload_1 = pickle.dumps(PicklePrint())
     144      payload_2 = pickle.dumps(("a", "b", "c", 1, 2, 3))
     145  
     146      # Before we add the hook, ensure our malicious pickle loads
     147      assertEqual("Pwned!", pickle.loads(payload_1))
     148  
     149      with TestHook(raise_on_events="pickle.find_class") as hook:
     150          with assertRaises(RuntimeError):
     151              # With the hook enabled, loading globals is not allowed
     152              pickle.loads(payload_1)
     153          # pickles with no globals are okay
     154          pickle.loads(payload_2)
     155  
     156  
def test_monkeypatch():
    """Check which class and instance mutations raise "object.__setattr__".

    Only four events are expected: the ``__name__`` change, the two
    ``__bases__`` changes and the ``__class__`` change. The expected list
    has no entries for the ``__init__`` and ``new_attr`` assignments.
    """
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    a = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Catch attribute replacement
        C.__init__ = B.__init__
        # Catch attribute addition
        C.new_attr = 123
        # Catch class changes
        a.__class__ = B

    # The comprehension's ``a`` is scoped to the comprehension, so the
    # instance ``a`` used in the expected list below is unaffected.
    actual = [(a[0], a[1]) for e, a in hook.seen if e == "object.__setattr__"]
    assertSequenceEqual(
        [(C, "__name__"), (C, "__bases__"), (C, "__bases__"), (a, "__class__")], actual
    )
     187  
     188  
def test_open():
    """Check that several ways of opening a file raise the "open" event.

    The hook raises RuntimeError on every "open" event, so each call is
    expected to fail; the recorded (path, mode) pairs are then compared
    with the expected list. ``sys.argv[2]`` is used as a file path
    (presumably supplied by the driving test; confirm against the caller).
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions.
    # All of them should fail
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            # load_dh_params is None when ssl could not be imported.
            if not fn:
                continue
            with assertRaises(RuntimeError):
                fn(*args)

    # Split the events on whether the second argument (the mode) is truthy;
    # events without a mode are collected with their third argument instead.
    actual_mode = [(a[0], a[1]) for e, a in hook.seen if e == "open" and a[1]]
    actual_flag = [(a[0], a[2]) for e, a in hook.seen if e == "open" and not a[1]]
    assertSequenceEqual(
        [
            i
            for i in [
                # The expected modes for the "rb" and "wb" calls above are
                # "r" and "w".
                (sys.argv[2], "r"),
                (sys.executable, "r"),
                (3, "w"),
                (sys.argv[2], "w"),
                (sys.argv[2], "rb") if load_dh_params else None,
            ]
            if i is not None
        ],
        actual_mode,
    )
    assertSequenceEqual([], actual_flag)
     230  
     231  
     232  def test_cantrace():
     233      traced = []
     234  
     235      def trace(frame, event, *args):
     236          if frame.f_code == TestHook.__call__.__code__:
     237              traced.append(event)
     238  
     239      old = sys.settrace(trace)
     240      try:
     241          with TestHook() as hook:
     242              # No traced call
     243              eval("1")
     244  
     245              # No traced call
     246              hook.__cantrace__ = False
     247              eval("2")
     248  
     249              # One traced call
     250              hook.__cantrace__ = True
     251              eval("3")
     252  
     253              # Two traced calls (writing to private member, eval)
     254              hook.__cantrace__ = 1
     255              eval("4")
     256  
     257              # One traced call (writing to private member)
     258              hook.__cantrace__ = 0
     259      finally:
     260          sys.settrace(old)
     261  
     262      assertSequenceEqual(["call"] * 4, traced)
     263  
     264  
     265  def test_mmap():
     266      import mmap
     267  
     268      with TestHook() as hook:
     269          mmap.mmap(-1, 8)
     270          assertEqual(hook.seen[0][1][:2], (-1, 8))
     271  
     272  
     273  def test_excepthook():
     274      def excepthook(exc_type, exc_value, exc_tb):
     275          if exc_type is not RuntimeError:
     276              sys.__excepthook__(exc_type, exc_value, exc_tb)
     277  
     278      def hook(event, args):
     279          if event == "sys.excepthook":
     280              if not isinstance(args[2], args[1]):
     281                  raise TypeError(f"Expected isinstance({args[2]!r}, " f"{args[1]!r})")
     282              if args[0] != excepthook:
     283                  raise ValueError(f"Expected {args[0]} == {excepthook}")
     284              print(event, repr(args[2]))
     285  
     286      sys.addaudithook(hook)
     287      sys.excepthook = excepthook
     288      raise RuntimeError("fatal-error")
     289  
     290  
     291  def test_unraisablehook():
     292      from _testcapi import write_unraisable_exc
     293  
     294      def unraisablehook(hookargs):
     295          pass
     296  
     297      def hook(event, args):
     298          if event == "sys.unraisablehook":
     299              if args[0] != unraisablehook:
     300                  raise ValueError(f"Expected {args[0]} == {unraisablehook}")
     301              print(event, repr(args[1].exc_value), args[1].err_msg)
     302  
     303      sys.addaudithook(hook)
     304      sys.unraisablehook = unraisablehook
     305      write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
     306  
     307  
     308  def test_winreg():
     309      from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE
     310  
     311      def hook(event, args):
     312          if not event.startswith("winreg."):
     313              return
     314          print(event, *args)
     315  
     316      sys.addaudithook(hook)
     317  
     318      k = OpenKey(HKEY_LOCAL_MACHINE, "Software")
     319      EnumKey(k, 0)
     320      try:
     321          EnumKey(k, 10000)
     322      except OSError:
     323          pass
     324      else:
     325          raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")
     326  
     327      kv = k.Detach()
     328      CloseKey(kv)
     329  
     330  
     331  def test_socket():
     332      import socket
     333  
     334      def hook(event, args):
     335          if event.startswith("socket."):
     336              print(event, *args)
     337  
     338      sys.addaudithook(hook)
     339  
     340      socket.gethostname()
     341  
     342      # Don't care if this fails, we just want the audit message
     343      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     344      try:
     345          # Don't care if this fails, we just want the audit message
     346          sock.bind(('127.0.0.1', 8080))
     347      except Exception:
     348          pass
     349      finally:
     350          sock.close()
     351  
     352  
     353  def test_gc():
     354      import gc
     355  
     356      def hook(event, args):
     357          if event.startswith("gc."):
     358              print(event, *args)
     359  
     360      sys.addaudithook(hook)
     361  
     362      gc.get_objects(generation=1)
     363  
     364      x = object()
     365      y = [x]
     366  
     367      gc.get_referrers(x)
     368      gc.get_referents(y)
     369  
     370  
     371  def test_http_client():
     372      import http.client
     373  
     374      def hook(event, args):
     375          if event.startswith("http.client."):
     376              print(event, *args[1:])
     377  
     378      sys.addaudithook(hook)
     379  
     380      conn = http.client.HTTPConnection('www.python.org')
     381      try:
     382          conn.request('GET', '/')
     383      except OSError:
     384          print('http.client.send', '[cannot send]')
     385      finally:
     386          conn.close()
     387  
     388  
     389  def test_sqlite3():
     390      import sqlite3
     391  
     392      def hook(event, *args):
     393          if event.startswith("sqlite3."):
     394              print(event, *args)
     395  
     396      sys.addaudithook(hook)
     397      cx1 = sqlite3.connect(":memory:")
     398      cx2 = sqlite3.Connection(":memory:")
     399  
     400      # Configured without --enable-loadable-sqlite-extensions
     401      if hasattr(sqlite3.Connection, "enable_load_extension"):
     402          cx1.enable_load_extension(False)
     403          try:
     404              cx1.load_extension("test")
     405          except sqlite3.OperationalError:
     406              pass
     407          else:
     408              raise RuntimeError("Expected sqlite3.load_extension to fail")
     409  
     410  
     411  def test_sys_getframe():
     412      import sys
     413  
     414      def hook(event, args):
     415          if event.startswith("sys."):
     416              print(event, args[0].f_code.co_name)
     417  
     418      sys.addaudithook(hook)
     419      sys._getframe()
     420  
     421  
     422  def test_syslog():
     423      import syslog
     424  
     425      def hook(event, args):
     426          if event.startswith("syslog."):
     427              print(event, *args)
     428  
     429      sys.addaudithook(hook)
     430      syslog.openlog('python')
     431      syslog.syslog('test')
     432      syslog.setlogmask(syslog.LOG_DEBUG)
     433      syslog.closelog()
     434      # implicit open
     435      syslog.syslog('test2')
     436      # open with default ident
     437      syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0)
     438      sys.argv = None
     439      syslog.openlog()
     440      syslog.closelog()
     441  
     442  
     443  def test_not_in_gc():
     444      import gc
     445  
     446      hook = lambda *a: None
     447      sys.addaudithook(hook)
     448  
     449      for o in gc.get_objects():
     450          if isinstance(o, list):
     451              assert hook not in o
     452  
     453  
# Script entry point: run the single test function named by sys.argv[1].
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    # The test is looked up by name among this module's globals and called
    # with no arguments; an unknown name raises KeyError.
    test = sys.argv[1]
    globals()[test]()