1 import asyncio
2 import asyncio.events
3 import contextlib
4 import os
5 import pprint
6 import select
7 import socket
8 import tempfile
9 import threading
10 from test import support
11
12
class FunctionalTestCaseMixin:
    """Test-case mixin that owns a private asyncio event loop.

    Meant to be combined with ``unittest.TestCase`` (it relies on
    ``self.fail``).  ``setUp`` creates a new loop without installing it as
    the current event loop and records every call to the loop's exception
    handler; ``tearDown`` closes the loop and fails the test if any such
    call was recorded.  The ``tcp_*`` / ``unix_*`` helpers build threaded
    blocking-socket peers that code running on ``self.loop`` can talk to.
    """

    def new_loop(self):
        """Return the event loop used by the test (a new default loop)."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run ``self.loop`` until an ``asyncio.sleep(delay)`` completes."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* for ``tearDown``, then delegate to the default
        handler of ``self.loop``."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        self.loop = self.new_loop()
        # Leave no loop installed as "current"; tests use self.loop explicitly.
        asyncio.set_event_loop(None)

        # Create the record before installing the handler that appends to it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

    def tearDown(self):
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            asyncio.set_event_loop(None)
            self.loop = None

    @staticmethod
    def _check_socket_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number.

        The threaded peers only support blocking sockets with a timeout.
        """
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=support.LOOPBACK_TIMEOUT,
                   backlog=1,
                   max_clients=10):
        """Create a threaded server that runs *server_prog* for each client.

        A listening socket of *family* is bound to *addr*.  The default is
        an ephemeral port on 127.0.0.1, or a fresh temporary path for
        ``AF_UNIX``.  Each accepted connection is handed to *server_prog*
        wrapped in a ``TestSocketWrapper``, for at most *max_clients*
        connections.  Use the returned ``TestThreadedServer`` as a context
        manager to start and stop it.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate before creating the socket so that a bad timeout does not
        # leave an open listening socket behind.
        self._check_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Borrow a unique temporary path.  The file is removed when
                # the context exits, so the server can bind to that name.
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        sock.settimeout(timeout)

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=support.LOOPBACK_TIMEOUT):
        """Create a threaded client that runs *client_prog* on a new socket.

        *client_prog* receives an unconnected stream socket of *family*,
        wrapped in a ``TestSocketWrapper``.  Use the returned
        ``TestThreadedClient`` as a context manager.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate first so that a bad timeout does not leak the new socket.
        self._check_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """``tcp_server`` with ``family=AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """``tcp_client`` with ``family=AF_UNIX``.

        Raises NotImplementedError where AF_UNIX is unavailable.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a not-yet-existing path inside a fresh temporary directory.

        On exit, any file created at that path is removed (errors ignored),
        and then the directory itself is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    pass

    def _abort_socket_test(self, ex):
        """Stop ``self.loop`` and fail the test with *ex*.

        Called by the threaded socket peers when their program raises.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)
111
112
##############################################################################
# Socket Testing Utilities
##############################################################################
116
117
class TestSocketWrapper:
    """Proxy around a blocking socket handed to threaded test programs.

    Attributes not defined here (``recv``, ``sendall``, ``close``, ...) are
    forwarded to the wrapped socket.  ``start_tls`` replaces the wrapped
    socket in place with an SSL socket, so the program keeps using the same
    wrapper object after the upgrade.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        before *n* bytes have arrived.
        """
        # Accumulate in a bytearray: repeated ``bytes +=`` is quadratic.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the current socket with *ssl_context* and run the handshake.

        On success, the wrapper proxies the new SSL socket.  If the
        handshake raises, the SSL socket is closed and the error re-raised.
        The previously wrapped socket object is closed in either case.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()

        self.__sock = ssl_sock

    def __getattr__(self, name):
        # __getattr__ only runs for names not found normally.  If the
        # wrapped socket itself is missing (an instance created via
        # __new__, or during copy/unpickling), looking up self.__sock would
        # re-enter __getattr__ endlessly; fail with AttributeError instead.
        if name == '_TestSocketWrapper__sock':
            raise AttributeError(name)
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
156
157
class SocketThread(threading.Thread):
    """Thread base class that doubles as a context manager.

    Entering a ``with`` block starts the thread.  Leaving it calls
    ``stop()``, which clears the ``_active`` flag and waits for ``run()``
    to return.  Subclasses presumably check ``_active`` in their ``run()``
    loop.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        # Always stop the thread; exceptions from the with-body propagate.
        self.stop()

    def stop(self):
        """Signal the thread to finish, then block until it has."""
        self._active = False
        self.join(None)
170
171
class TestThreadedClient(SocketThread):
    """Background thread that runs a blocking client program on *sock*.

    *prog* is called once with a ``TestSocketWrapper`` around *sock*.  Any
    ``Exception`` escaping *prog* is reported through
    ``test._abort_socket_test``.
    """

    def __init__(self, test, sock, prog, timeout):
        threading.Thread.__init__(self, None, None, 'test-client')
        self.daemon = True

        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        wrapper = TestSocketWrapper(self._sock)
        try:
            try:
                self._prog(wrapper)
            finally:
                # Close whichever socket the wrapper now holds: the plain
                # socket, or the SSL socket if prog called start_tls().
                # Without this the socket leaks whenever prog returns or
                # raises without closing it.  Closing again after prog
                # already closed it does not raise.
                wrapper.close()
        except Exception as ex:
            self._test._abort_socket_test(ex)
189
190
class TestThreadedServer(SocketThread):
    """Background thread that accepts connections and runs a server program.

    Connections on the listening socket *sock* are accepted and handled one
    at a time: each is passed to *prog* wrapped in a ``TestSocketWrapper``,
    and the next ``accept()`` only happens after *prog* returns.  The thread
    exits after *max_clients* connections, when ``stop()`` is called, or
    when handling a client raises.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        # Daemon thread: does not block interpreter exit if it hangs.
        self.daemon = True

        # Number of connections accepted so far; compared to _max_clients.
        self._clients = 0
        # NOTE(review): not updated anywhere in the code shown here.
        self._finished_clients = 0
        self._max_clients = max_clients
        # Used both as the select() timeout and as each connection's timeout.
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Wake-up channel: stop() writes to _s2, which makes _s1 readable
        # and ends the select() in _run() without waiting out the timeout.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the server loop, then clear ``_active`` and join the thread."""
        try:
            # _s2 has fileno() == -1 once run() has closed it.
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    # Best effort: the thread may already be shutting down.
                    pass
        finally:
            super().stop()

    def run(self):
        try:
            # The listening socket is closed when the accept loop exits.
            with self._sock:
                # Readiness is decided by select() in _run().
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept and serve connections until one of the exit conditions
        described on the class is reached."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            # Wait for a pending connection or a stop() wake-up.  On timeout
            # both lists are empty and the loop re-checks _active.
            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in r:
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    # Readiness reported but nothing to accept; wait again.
                    continue
                except TimeoutError:
                    # Presumably defensive, since the listening socket is
                    # non-blocking here; tolerated only after stop().
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except Exception as ex:
                        self._active = False
                        # Re-raise the client error, but always report it
                        # first.  If _abort_socket_test itself raises (the
                        # mixin's version calls self.fail), that exception
                        # propagates instead, with the original error as
                        # its __context__.
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        # NOTE(review): if prog calls start_tls() on the wrapper, the SSL
        # socket it creates is not the `conn` closed by `with conn` in
        # _run(); confirm that server programs close it themselves.
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to.

        Only valid while that socket is open (run() closes it on exit).
        """
        return self._sock.getsockname()