1 import socket
2 import asyncio
3 import sys
4 import unittest
5
6 from asyncio import proactor_events
7 from itertools import cycle, islice
8 from unittest.mock import patch, Mock
9 from test.test_asyncio import utils as test_utils
10 from test import support
11 from test.support import socket_helper
12
# Several tests below (e.g. test_sock_client_fail, test_cancel_sock_accept)
# expect connecting to a closed port to raise ConnectionRefusedError. With the
# FreeBSD TCP blackhole sysctl enabled the kernel may not send a refusal, so
# skip the whole module in that configuration.
if socket_helper.tcp_blackhole():
    raise unittest.SkipTest(
        'TCP blackhole is enabled (sysctl net.inet.tcp.blackhole)')
15
16
def tearDownModule():
    """Restore the default asyncio event loop policy after this module runs."""
    default_policy = None  # None asks asyncio to reinstall its default policy
    asyncio.set_event_loop_policy(default_policy)
19
20
class MyProto(asyncio.Protocol):
    """Client protocol that sends one HTTP/1.0 GET and tracks its lifecycle.

    The state moves INITIAL -> CONNECTED -> (EOF) -> CLOSED; each callback
    checks the current state first and raises AssertionError when it is
    called out of order. When *loop* is given, ``connected`` and ``done``
    become futures that are resolved by connection_made() and
    connection_lost() respectively; otherwise they stay ``None``.
    """

    # Class-level defaults; replaced by futures only when a loop is supplied.
    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0  # running total of bytes seen by data_received()
        if loop is None:
            return
        self.connected = loop.create_future()
        self.done = loop.create_future()

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        current = self.state
        if current in expected:
            return
        raise AssertionError(f'state: {current!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        waiter = self.connected
        if waiter:
            waiter.set_result(None)
        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        finished = self.done
        if finished:
            finished.set_result(None)
58
59
class BaseSockTestsMixin:
    """Tests for the low-level ``loop.sock_*`` socket APIs.

    The concrete test classes at the bottom of the module combine this mixin
    with ``test_utils.TestCase`` and override :meth:`create_event_loop` to
    choose the event loop implementation under test.
    """

    def create_event_loop(self):
        """Return the event loop to test; subclasses must override this."""
        raise NotImplementedError

    def setUp(self):
        self.loop = self.create_event_loop()
        self.set_event_loop(self.loop)
        super().setUp()

    def tearDown(self):
        # Run the loop once more in case transport close callbacks are
        # still pending.
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def _basetest_sock_client_ops(self, httpd, sock):
        """Run one HTTP request/response over *sock* with the sock_* APIs.

        On loops other than the proactor, first check that debug mode makes
        every sock_* call on a blocking socket raise ValueError. *sock* is
        closed before returning.
        """
        if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # in debug mode, socket operations must fail
            # if the socket is in blocking mode
            self.loop.set_debug(True)
            sock.setblocking(True)
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, httpd.address))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv(sock, 1024))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_recv_into(sock, bytearray()))
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(
                    self.loop.sock_accept(sock))

        # test in non-blocking mode
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        # consume data
        self.loop.run_until_complete(
            self.loop.sock_recv(sock, 1024))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def _basetest_sock_recv_into(self, httpd, sock):
        """Same as _basetest_sock_client_ops, but reading via sock_recv_into.

        *sock* is closed before returning.
        """
        sock.setblocking(False)
        self.loop.run_until_complete(
            self.loop.sock_connect(sock, httpd.address))
        self.loop.run_until_complete(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = bytearray(1024)
        with memoryview(data) as buf:
            nbytes = self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[:1024]))
            # consume data
            self.loop.run_until_complete(
                self.loop.sock_recv_into(sock, buf[nbytes:]))
        sock.close()
        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

    def test_sock_client_ops(self):
        # The helpers close their socket on success; ``with`` also closes
        # it if an assertion fails part-way through.
        with test_utils.run_test_server() as httpd:
            with socket.socket() as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket() as sock:
                self._basetest_sock_recv_into(httpd, sock)

    async def _basetest_sock_recv_racing(self, httpd, sock):
        """Cancel a pending sock_recv, then check that a new one still works."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
        await asyncio.sleep(0)
        task.cancel()

        # Keep a reference to the sender, so the task cannot be
        # garbage-collected mid-flight, and await it below so any failure
        # is reported.
        task = asyncio.create_task(
            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
        data = await self.loop.sock_recv(sock, 1024)
        # consume data
        await self.loop.sock_recv(sock, 1024)

        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_recv_into_racing(self, httpd, sock):
        """Cancel a pending sock_recv_into, then check a new one still works."""
        sock.setblocking(False)
        await self.loop.sock_connect(sock, httpd.address)

        data = bytearray(1024)
        with memoryview(data) as buf:
            task = asyncio.create_task(
                self.loop.sock_recv_into(sock, buf[:1024]))
            await asyncio.sleep(0)
            task.cancel()

            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
            nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
            # consume data
            await self.loop.sock_recv_into(sock, buf[nbytes:])
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

        await task

    async def _basetest_sock_send_racing(self, listener, sock):
        """Cancel a blocked sock_sendall, then check a new one is delivered."""
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)

        # make connection
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
        sock.setblocking(False)
        task = asyncio.create_task(
            self.loop.sock_connect(sock, listener.getsockname()))
        await asyncio.sleep(0)
        server = listener.accept()[0]
        server.setblocking(False)

        with server:
            await task

            # fill the buffer until sending 5 chars would block
            size = 8192
            while size >= 4:
                with self.assertRaises(BlockingIOError):
                    while True:
                        sock.send(b' ' * size)
                size //= 2

            # cancel a blocked sock_sendall
            task = asyncio.create_task(
                self.loop.sock_sendall(sock, b'hello'))
            await asyncio.sleep(0)
            task.cancel()

            # receive everything that is not a space
            async def recv_all():
                rv = b''
                while True:
                    buf = await self.loop.sock_recv(server, 8192)
                    if not buf:
                        return rv
                    rv += buf.strip()
            task = asyncio.create_task(recv_all())

            # immediately make another sock_sendall call
            await self.loop.sock_sendall(sock, b'world')
            sock.shutdown(socket.SHUT_WR)
            data = await task
            # ProactorEventLoop could deliver hello, so endswith is necessary
            self.assertTrue(data.endswith(b'world'))

    # After the first connect attempt before the listener is ready,
    # the socket needs time to "recover" to make the next connect call.
    # On Linux, a second retry will do. On Windows, the waiting time is
    # unpredictable; and on FreeBSD the socket may never come back
    # because it's a loopback address. Here we'll just retry for a few
    # times, and have to skip the test if it's not working. See also:
    # https://stackoverflow.com/a/54437602/3316267
    # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
    async def _basetest_sock_connect_racing(self, listener, sock):
        listener.bind(('127.0.0.1', 0))
        addr = listener.getsockname()
        sock.setblocking(False)

        task = asyncio.create_task(self.loop.sock_connect(sock, addr))
        await asyncio.sleep(0)
        task.cancel()

        listener.listen(1)

        skip_reason = "Max retries reached"
        for _ in range(128):
            try:
                await self.loop.sock_connect(sock, addr)
            except ConnectionRefusedError as e:
                skip_reason = e
            except OSError as e:
                skip_reason = e

                # Retry only for this error:
                # [WinError 10022] An invalid argument was supplied
                if getattr(e, 'winerror', 0) != 10022:
                    break
            else:
                # success
                return

        self.skipTest(skip_reason)

    def test_sock_client_racing(self):
        with test_utils.run_test_server() as httpd:
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_racing(httpd, sock), 10))
            sock = socket.socket()
            with sock:
                self.loop.run_until_complete(asyncio.wait_for(
                    self._basetest_sock_recv_into_racing(httpd, sock), 10))
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_send_racing(listener, sock), 10))

    def test_sock_client_connect_racing(self):
        listener = socket.socket()
        sock = socket.socket()
        with listener, sock:
            self.loop.run_until_complete(asyncio.wait_for(
                self._basetest_sock_connect_racing(listener, sock), 10))

    async def _basetest_huge_content(self, address):
        """POST DATA_SIZE bytes to the /loop echo endpoint and verify them."""
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        # ``with`` closes the socket even if an assertion below fails.
        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            data = await self.loop.sock_recv(sock, DATA_SIZE)
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while b'\r\n\r\n' not in data:
                data += await self.loop.sock_recv(sock, DATA_SIZE)
            # Strip headers
            headers = data[:data.index(b'\r\n\r\n') + 4]
            data = data[len(headers):]

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                data = await self.loop.sock_recv(sock, DATA_SIZE)
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content(httpd.address))

    async def _basetest_huge_content_recvinto(self, address):
        """Same as _basetest_huge_content, but reading via sock_recv_into."""
        DATA_SIZE = 1_000_000

        chunk = b'0123456789' * (DATA_SIZE // 10)

        # ``with`` closes the socket even if an assertion below fails.
        with socket.socket() as sock:
            sock.setblocking(False)
            await self.loop.sock_connect(sock, address)
            await self.loop.sock_sendall(sock,
                                         (b'POST /loop HTTP/1.0\r\n' +
                                          b'Content-Length: %d\r\n' % DATA_SIZE +
                                          b'\r\n'))

            task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))

            array = bytearray(DATA_SIZE)
            buf = memoryview(array)

            nbytes = await self.loop.sock_recv_into(sock, buf)
            data = bytes(buf[:nbytes])
            # HTTP headers size is less than MTU,
            # they are sent by the first packet always
            self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
            while b'\r\n\r\n' not in data:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                # Append, don't overwrite: the header terminator may be split
                # across reads and the bytes read earlier are still needed.
                data += bytes(buf[:nbytes])
            # Strip headers
            headers = data[:data.index(b'\r\n\r\n') + 4]
            data = data[len(headers):]

            size = DATA_SIZE
            checker = cycle(b'0123456789')

            expected = bytes(islice(checker, len(data)))
            self.assertEqual(data, expected)
            size -= len(data)

            while True:
                nbytes = await self.loop.sock_recv_into(sock, buf)
                data = buf[:nbytes]
                if not data:
                    break
                expected = bytes(islice(checker, len(data)))
                self.assertEqual(data, expected)
                size -= len(data)
            self.assertEqual(size, 0)

            await task

    def test_huge_content_recvinto(self):
        with test_utils.run_test_server() as httpd:
            self.loop.run_until_complete(
                self._basetest_huge_content_recvinto(httpd.address))

    async def _basetest_datagram_recvfrom(self, server_address):
        # Happy path, sock.sendto() returns immediately
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            await self.loop.sock_sendto(sock, data, server_address)
            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_recvfrom(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom(server_address))

    async def _basetest_datagram_recvfrom_into(self, server_address):
        # Happy path, sock.sendto() returns immediately
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)

            buf = bytearray(4096)
            data = b'\x01' * 4096
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf, data)
            self.assertEqual(from_addr, server_address)

            buf = bytearray(8192)
            await self.loop.sock_sendto(sock, data, server_address)
            num_bytes, from_addr = await self.loop.sock_recvfrom_into(
                sock, buf, 4096)
            self.assertEqual(num_bytes, 4096)
            self.assertEqual(buf[:4096], data[:4096])
            self.assertEqual(from_addr, server_address)

    def test_recvfrom_into(self):
        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_recvfrom_into(server_address))

    async def _basetest_datagram_sendto_blocking(self, server_address):
        # Sad path, sock.sendto() raises BlockingIOError
        # This involves patching sock.sendto() to raise BlockingIOError but
        # sendto() is not used by the proactor event loop
        data = b'\x01' * 4096
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setblocking(False)
            mock_sock = Mock(sock)
            mock_sock.gettimeout = sock.gettimeout
            mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
            mock_sock.fileno = sock.fileno
            # Restore the real sendto on the next loop iteration so the
            # retry after BlockingIOError succeeds.
            self.loop.call_soon(
                lambda: setattr(mock_sock, 'sendto', sock.sendto)
            )
            await self.loop.sock_sendto(mock_sock, data, server_address)

            received_data, from_addr = await self.loop.sock_recvfrom(
                sock, 4096)
            self.assertEqual(received_data, data)
            self.assertEqual(from_addr, server_address)

    def test_sendto_blocking(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                raise unittest.SkipTest('Not relevant to ProactorEventLoop')

        with test_utils.run_udp_echo_server() as server_address:
            self.loop.run_until_complete(
                self._basetest_datagram_sendto_blocking(server_address))

    @socket_helper.skip_unless_bind_unix_socket
    def test_unix_sock_client_ops(self):
        with test_utils.run_test_unix_server() as httpd:
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_client_ops(httpd, sock)
            with socket.socket(socket.AF_UNIX) as sock:
                self._basetest_sock_recv_into(httpd, sock)

    def test_sock_client_fail(self):
        # Make sure that we will get an unused port: bind, note the address,
        # then close the socket so nothing is listening there.
        with socket.socket() as s:
            s.bind(('127.0.0.1', 0))
            address = s.getsockname()

        with socket.socket() as sock:
            sock.setblocking(False)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(
                    self.loop.sock_connect(sock, address))

    def test_sock_accept(self):
        with socket.socket() as listener, socket.socket() as client:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            client.connect(listener.getsockname())

            f = self.loop.sock_accept(listener)
            conn, addr = self.loop.run_until_complete(f)
            with conn:
                self.assertEqual(conn.gettimeout(), 0)
                self.assertEqual(addr, client.getsockname())
                self.assertEqual(client.getpeername(), listener.getsockname())

    def test_cancel_sock_accept(self):
        with socket.socket() as listener:
            listener.setblocking(False)
            listener.bind(('127.0.0.1', 0))
            listener.listen(1)
            sockaddr = listener.getsockname()
            f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(f)

        # The listener is closed now, so connecting to it must be refused.
        with socket.socket() as client:
            client.setblocking(False)
            f = self.loop.sock_connect(client, sockaddr)
            with self.assertRaises(ConnectionRefusedError):
                self.loop.run_until_complete(f)

    def test_create_connection_sock(self):
        with test_utils.run_test_server() as httpd:
            sock = None
            infos = self.loop.run_until_complete(
                self.loop.getaddrinfo(
                    *httpd.address, type=socket.SOCK_STREAM))
            for family, type_, proto, cname, address in infos:
                try:
                    sock = socket.socket(family=family, type=type_, proto=proto)
                    sock.setblocking(False)
                    self.loop.run_until_complete(
                        self.loop.sock_connect(sock, address))
                except OSError:
                    # This candidate address did not work: release its socket
                    # and try the next one.
                    if sock is not None:
                        sock.close()
                        sock = None
                else:
                    break
            else:
                self.fail('Can not create socket.')

            f = self.loop.create_connection(
                lambda: MyProto(loop=self.loop), sock=sock)
            tr, pr = self.loop.run_until_complete(f)
            self.assertIsInstance(tr, asyncio.Transport)
            self.assertIsInstance(pr, asyncio.Protocol)
            self.loop.run_until_complete(pr.done)
            self.assertGreater(pr.nbytes, 0)
            tr.close()
548
549
if sys.platform == 'win32':

    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):
        # Windows: the default selector-based event loop.

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):
        # Windows: the proactor event loop.

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    # One concrete test class per selector backend; each class is only
    # defined when the ``selectors`` module provides that backend.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(BaseSockTestsMixin,
                                   test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(BaseSockTestsMixin,
                                  test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(BaseSockTestsMixin,
                                 test_utils.TestCase):

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # SelectSelector should always exist, so this class is unconditional.
    class SelectEventLoopTests(BaseSockTestsMixin,
                               test_utils.TestCase):

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)
595
596
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()