1 """This script contains the actual auditing tests.
2
3 It should not be imported directly, but should be run by the test_audit
4 module with arguments identifying each test.
5
6 """
7
8 import contextlib
9 import os
10 import sys
11
12
class TestHook:
    """Audit hook that records every event it sees while open.

    Intended for use in a ``with`` block: entering installs the hook with
    ``sys.addaudithook``; leaving calls ``close()``, after which the hook
    ignores every later event (it stays installed, but does nothing).

    raise_on_events: a single event name (str) or a collection of event
        names; when one of them is seen, ``exc_type`` is raised from the hook.
    exc_type: exception class raised for those events (default RuntimeError).
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        # Treat a bare string as one event name. Otherwise ``event in "..."``
        # would be a substring test and could match unrelated event names.
        if isinstance(raise_on_events, str):
            raise_on_events = (raise_on_events,)
        self.raise_on_events = raise_on_events or ()
        self.exc_type = exc_type
        self.seen = []  # (event, args) tuples in the order they arrived
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording and raising; later calls return immediately."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
46
47
# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal."""
    mismatch = x != y
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
52
53
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    if el in series:
        return
    raise AssertionError(f"{el!r} should be in {series!r}")
57
58
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    found = el in series
    if found:
        raise AssertionError(f"{el!r} should not be in {series!r}")
62
63
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have equal length and items.

    Items are compared pairwise with ``!=``; the sequence types themselves
    may differ (a list can equal a tuple).
    """
    if len(x) == len(y) and not any(left != right for left, right in zip(x, y)):
        return
    raise AssertionError(f"{x!r} should equal {y!r}")
69
70
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager asserting that its body raises exactly *ex_type*.

    The exception type must match exactly (subclasses are rejected). A
    matching exception is swallowed. An AssertionError raised in the body
    propagates unchanged.

    Failures are raised explicitly rather than via ``assert`` so the
    check still works when the script runs under ``python -O``.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")
80
81
def test_basic():
    """A single sys.audit() call reaches the hook with its name and args."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
87
88
def test_block_add_hook():
    """An Exception from a hook blocks sys.addaudithook but does not escape.

    hook1 raises on the "sys.addaudithook" event, so hook2 never gets
    installed and sees nothing.
    """
    # Raising an exception should prevent a new hook from being added,
    # but will not propagate out.
    # Pass a set so the event match is exact, regardless of how TestHook
    # treats a bare string (``in`` on a str is a substring test).
    with TestHook(raise_on_events={"sys.addaudithook"}) as hook1:
        with TestHook() as hook2:
            sys.audit("test_event")
            assertIn("test_event", hook1.seen_events)
            assertNotIn("test_event", hook2.seen_events)
97
98
def test_block_add_hook_baseexception():
    """A BaseException raised by a hook propagates out of sys.addaudithook."""
    # Raising BaseException will propagate out when adding a hook.
    # A set keeps the event match exact rather than a substring test.
    with assertRaises(BaseException):
        with TestHook(
            raise_on_events={"sys.addaudithook"}, exc_type=BaseException
        ):
            # Adding this next hook should raise BaseException
            with TestHook():
                pass
108
109
def test_marshal():
    """marshal.dumps/dump, loads and load each raise the expected audit event."""
    import marshal

    data = ("a", "b", "c", 1, 2, 3)
    expected_payload = marshal.dumps(data)
    filename = "test-marshal.bin"

    with TestHook() as hook:
        assertEqual(data, marshal.loads(marshal.dumps(data)))

        try:
            with open(filename, "wb") as f:
                marshal.dump(data, f)
            with open(filename, "rb") as f:
                assertEqual(data, marshal.load(f))
        finally:
            os.unlink(filename)

    # marshal.dumps() and marshal.dump() both report "marshal.dumps".
    dumps_args = [
        (args[0], args[1]) for name, args in hook.seen if name == "marshal.dumps"
    ]
    assertSequenceEqual(dumps_args, [(data, marshal.version)] * 2)

    loads_args = [args[0] for name, args in hook.seen if name == "marshal.loads"]
    assertSequenceEqual(loads_args, [expected_payload])

    load_events = [name for name, _ in hook.seen if name == "marshal.load"]
    assertSequenceEqual(load_events, ["marshal.load"])
134
135
def test_pickle():
    """A hook raising on "pickle.find_class" blocks unpickling of globals."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    # The first payload references a global (``str``); the second is plain data.
    malicious = pickle.dumps(PicklePrint())
    harmless = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Without a hook, the malicious pickle loads fine.
    assertEqual("Pwned!", pickle.loads(malicious))

    with TestHook(raise_on_events="pickle.find_class"):
        # With the hook enabled, resolving globals is refused...
        with assertRaises(RuntimeError):
            pickle.loads(malicious)
        # ...while pickles that need no globals still load.
        pickle.loads(harmless)
155
156
def test_monkeypatch():
    """Renaming a class, rebasing it and changing __class__ are audited."""
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    instance = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Plain attribute replacement and addition on the class: the
        # expected list below contains no entry for these two.
        C.__init__ = B.__init__
        C.new_attr = 123
        # Catch class changes
        instance.__class__ = B

    observed = [
        (args[0], args[1])
        for event, args in hook.seen
        if event == "object.__setattr__"
    ]
    expected = [
        (C, "__name__"),
        (C, "__bases__"),
        (C, "__bases__"),
        (instance, "__class__"),
    ]
    assertSequenceEqual(expected, observed)
187
188
def test_open():
    """Every kind of open goes through the "open" audit event.

    The hook raises on "open", so each attempt must fail with RuntimeError.
    Afterwards the recorded (path, mode) pairs are compared with the
    expected list.
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    target = sys.argv[2]
    attempts = [
        (open, target, "r"),
        (open, sys.executable, "rb"),
        (open, 3, "wb"),
        (open, target, "w", -1, None, None, None, False, lambda *a: 1),
        (load_dh_params, target),
    ]

    # Try a range of "open" functions; all of them should fail.
    with TestHook(raise_on_events={"open"}) as hook:
        for func, *call_args in attempts:
            if func:
                with assertRaises(RuntimeError):
                    func(*call_args)

    open_args = [args for event, args in hook.seen if event == "open"]
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]

    expected_mode = [
        (target, "r"),
        (sys.executable, "r"),
        (3, "w"),
        (target, "w"),
    ]
    if load_dh_params:
        expected_mode.append((target, "rb"))

    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
230
231
def test_cantrace():
    """Calls into a hook are traced only while its ``__cantrace__`` is truthy.

    A tracer records "call" events whose code object is TestHook.__call__.
    The steps below are expected to yield exactly four such calls.
    """
    traced = []

    def trace(frame, event, *args):
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so read the previous tracer explicitly.
    # This lets the finally block restore it instead of clearing tracing.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
263
264
def test_mmap():
    """Creating an anonymous mmap reports fileno -1 and the length first."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        first_args = hook.seen[0][1]
        assertEqual(first_args[:2], (-1, 8))
271
272
def test_excepthook():
    """An uncaught exception raises "sys.excepthook" with the hook and exception.

    The audit hook checks the arguments and prints the event for the
    parent test to inspect. The custom excepthook then silences the
    deliberate RuntimeError.
    """
    def excepthook(exc_type, exc_value, exc_tb):
        # Swallow only the RuntimeError raised on purpose below.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event != "sys.excepthook":
            return
        installed, exc_type, exc_value = args[0], args[1], args[2]
        if not isinstance(exc_value, exc_type):
            raise TypeError(f"Expected isinstance({exc_value!r}, {exc_type!r})")
        if installed != excepthook:
            raise ValueError(f"Expected {installed} == {excepthook}")
        print(event, repr(exc_value))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
289
290
def test_unraisablehook():
    """An unraisable exception raises "sys.unraisablehook" with our hook."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass  # deliberately ignore the reported exception

    def hook(event, args):
        if event != "sys.unraisablehook":
            return
        installed, hookargs = args[0], args[1]
        if installed != unraisablehook:
            raise ValueError(f"Expected {installed} == {unraisablehook}")
        print(event, repr(hookargs.exc_value), hookargs.err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
306
307
def test_winreg():
    """Print the winreg.* audit events for open, enumerate, detach and close."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)
    # An out-of-range index must fail with OSError.
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    raw_handle = key.Detach()
    CloseKey(raw_handle)
329
330
def test_socket():
    """Print the socket.* audit events for gethostname and bind."""
    import socket

    def hook(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # The with-statement closes the socket, just like try/finally close().
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            # Don't care if this fails, we just want the audit message
            sock.bind(('127.0.0.1', 8080))
        except Exception:
            pass
351
352
def test_gc():
    """Print the gc.* audit events for get_objects/get_referrers/get_referents."""
    import gc

    def hook(event, args):
        if not event.startswith("gc."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    target = object()
    container = [target]

    gc.get_referrers(target)
    gc.get_referents(container)
369
370
def test_http_client():
    """Print http.client.* audit events (minus the connection object)."""
    import http.client

    def hook(event, args):
        if event.startswith("http.client."):
            # args[0] is skipped when printing.
            print(event, *args[1:])

    sys.addaudithook(hook)

    # closing() calls conn.close() on exit, as the original finally did.
    with contextlib.closing(http.client.HTTPConnection('www.python.org')) as conn:
        try:
            conn.request('GET', '/')
        except OSError:
            print('http.client.send', '[cannot send]')
387
388
def test_sqlite3():
    """Print sqlite3.* audit events for connecting and loading extensions."""
    import sqlite3

    # Audit hooks are called as hook(event, args), so ``*args`` here is a
    # 1-tuple holding the args tuple; the print format relies on that.
    def hook(event, *args):
        if not event.startswith("sqlite3."):
            return
        print(event, *args)

    sys.addaudithook(hook)
    cx1 = sqlite3.connect(":memory:")
    cx2 = sqlite3.Connection(":memory:")

    # Configured without --enable-loadable-sqlite-extensions
    if not hasattr(sqlite3.Connection, "enable_load_extension"):
        return

    cx1.enable_load_extension(False)
    try:
        cx1.load_extension("test")
    except sqlite3.OperationalError:
        pass
    else:
        raise RuntimeError("Expected sqlite3.load_extension to fail")
409
410
def test_sys_getframe():
    """Print sys.* audit events with the name of the frame's code object."""
    def hook(event, args):
        if not event.startswith("sys."):
            return
        frame = args[0]
        print(event, frame.f_code.co_name)

    sys.addaudithook(hook)
    sys._getframe()
420
421
def test_sys_getframemodulename():
    """Print the sys.* audit event raised by sys._getframemodulename()."""
    def hook(event, args):
        if not event.startswith("sys."):
            return
        print(event, *args)

    sys.addaudithook(hook)
    sys._getframemodulename()
431
432
def test_threading():
    """Print thread-creation audit events and one raised from the new thread."""
    import _thread

    def hook(event, args):
        if event.startswith(("_thread.", "cpython.PyThreadState", "test.")):
            print(event, args)

    sys.addaudithook(hook)

    # The main thread holds this lock until the worker releases it, so the
    # final acquire() waits for the worker to run.
    done = _thread.allocate_lock()
    done.acquire()

    class test_func:
        def __repr__(self):
            return "<test_func>"

        def __call__(self):
            sys.audit("test.test_func")
            done.release()

    _thread.start_new_thread(test_func(), ())
    done.acquire()
453
454
def test_threading_abort():
    """A hook aborting PyThreadState_New surfaces its own exception type."""
    import _thread

    class ThreadNewAbortError(Exception):
        pass

    def hook(event, args):
        if event == "cpython.PyThreadState_New":
            raise ThreadNewAbortError()

    sys.addaudithook(hook)

    # Only the hook's exception is tolerated; any other exception escapes
    # and fails the test.
    with contextlib.suppress(ThreadNewAbortError):
        _thread.start_new_thread(lambda: None, ())
473
474
def test_wmi_exec_query():
    """Print the _wmi.* audit event (first argument only) for a WMI query."""
    import _wmi

    def hook(event, args):
        if not event.startswith("_wmi."):
            return
        print(event, args[0])

    sys.addaudithook(hook)
    _wmi.exec_query("SELECT * FROM Win32_OperatingSystem")
484
def test_syslog():
    """Print syslog.* audit events for explicit, implicit and default opens."""
    import syslog

    def hook(event, args):
        if not event.startswith("syslog."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    syslog.openlog('python')
    syslog.syslog('test')
    syslog.setlogmask(syslog.LOG_DEBUG)
    syslog.closelog()

    # implicit open
    syslog.syslog('test2')

    # open with default ident
    syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0)
    # Clearing sys.argv (process-wide) is deliberate: the next openlog() then
    # runs with no argv to take an ident from.
    sys.argv = None
    syslog.openlog()
    syslog.closelog()
504
505
def test_not_in_gc():
    """The list of installed audit hooks must not show up in gc.get_objects().

    A no-op hook is added, then every gc-tracked list is scanned for it.
    The failure is raised explicitly: a bare ``assert`` would be stripped
    under ``python -O``, so this test could never fail there.
    """
    import gc

    def hook(*a):
        return None

    sys.addaudithook(hook)

    for o in gc.get_objects():
        if isinstance(o, list) and hook in o:
            raise AssertionError("audit hook found in a gc-tracked list")
515
516
def test_sys_monitoring_register_callback():
    """Print sys.monitoring* audit events raised by register_callback()."""
    def hook(event, args):
        if not event.startswith("sys.monitoring"):
            return
        print(event, args)

    sys.addaudithook(hook)
    sys.monitoring.register_callback(1, 1, None)
526
527
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    # sys.argv[1] names the test function in this module to run.
    selected_test = globals()[sys.argv[1]]
    selected_test()