1 """Tests for sendfile functionality."""
2
3 import asyncio
4 import errno
5 import os
6 import socket
7 import sys
8 import tempfile
9 import unittest
10 from asyncio import base_events
11 from asyncio import constants
12 from unittest import mock
13 from test import support
14 from test.support import os_helper
15 from test.support import socket_helper
16 from test.test_asyncio import utils as test_utils
17
18 try:
19 import ssl
20 except ImportError:
21 ssl = None
22
23
def tearDownModule():
    """Reset the global asyncio event loop policy after this module runs."""
    # Passing None to set_event_loop_policy() restores asyncio's default
    # policy, so policy state set by these tests does not leak further.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
26
27
class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and checks callback ordering.

    The state moves INITIAL -> CONNECTED -> (EOF ->) CLOSED, and every
    callback asserts it is invoked from a permitted state.  When *loop* is
    given, ``connected`` and ``done`` are futures resolved by
    connection_made() and connection_lost() respectively; otherwise both
    are None.  When *close_after* is non-zero, the transport is closed as
    soon as at least that many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        else:
            # Without a loop there is nothing to create futures on.  Store
            # None so the callbacks can check for it instead of failing
            # with AttributeError.
            self.connected = None
            self.done = None
        self.data = bytearray()
        self.close_after = close_after

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(
                f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after receiving close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()
68
69
class MyProto(asyncio.Protocol):
    """Protocol that accumulates received data and signals when closed.

    ``fut`` is created on the given loop and resolved by connection_lost(),
    so wait_closed() can be awaited to block until the connection is gone.
    """

    def __init__(self, loop):
        self.transport = None
        self.data = bytearray()
        self.started = False
        self.closed = False
        self.fut = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        self.started = True

    def data_received(self, data):
        self.data.extend(data)

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Wait until connection_lost() has been called."""
        return await self.fut
92
93
class SendfileBase:
    """Fixture mixin: a payload file on disk and a fresh event loop per test.

    Subclasses must implement create_event_loop().  The mixin is meant to be
    combined with a TestCase class providing set_event_loop() (in this file,
    test_utils.TestCase).
    """

    # The payload is 256 KiB plus one byte, so it is not aligned to a
    # buffer-sized chunk.  Newer versions of Windows (newer server releases
    # and Windows 11) seem to have a larger internal send buffer, somewhat
    # below 256 KiB, and try to push as much data through it as they can.
    # DATA must therefore be larger than 256 KiB for the tests to be reliable.
    DATA = b"x" * (256 * 1024 + 1)
    # Shrink socket buffers so relatively small data sets still exercise
    # partial sends.
    BUF_SIZE = 4 * 1024  # 4 KiB

    def create_event_loop(self):
        """Return a new event loop; concrete subclasses must override this."""
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        # Write the payload once for the whole class.
        with open(os_helper.TESTFN, 'wb') as out:
            out.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        fileobj = open(os_helper.TESTFN, 'rb')
        self.file = fileobj
        self.addCleanup(fileobj.close)
        loop = self.create_event_loop()
        self.loop = loop
        self.set_event_loop(loop)
        super().setUp()

    def tearDown(self):
        # Give any pending transport-close callbacks a chance to run.
        if not self.loop.is_closed():
            test_utils.run_briefly(self.loop)
        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on self.loop and return its result."""
        loop = self.loop
        return loop.run_until_complete(coro)
138
139
class SockSendfileMixin(SendfileBase):
    """Tests for loop.sock_sendfile() over raw non-blocking TCP sockets."""

    @classmethod
    def setUpClass(cls):
        # Use a 16 KiB fallback read buffer for the duration of this class;
        # the previous value is restored in tearDownClass().
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a non-blocking IPv4 TCP socket.

        The socket is closed at test cleanup unless *cleanup* is False.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce the receive socket buffer size to test on relatively
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce the send socket buffer size (and, if given, the transport's
        # high-water mark) to test on relatively small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a MyProto server and connect a client socket to it.

        Returns ``(client_sock, server_proto)``.  A cleanup is registered
        that closes the server-side transport (if any) and the server.
        """
        proto = MyProto(self.loop)
        port = socket_helper.find_unused_port()
        # The server owns srv_sock; server.close() in cleanup releases it.
        srv_sock = self.make_socket(cleanup=False)
        srv_sock.bind((socket_helper.HOST, port))
        server = self.run_loop(self.loop.create_server(
            lambda: proto, sock=srv_sock))
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # The empty temporary file is the one being sent, so its
            # position is what must stay at 0.  Check it before the with
            # block closes the file.
            self.assertEqual(f.tell(), 0)
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(proto.data, b'')
        # self.file was not involved and must be untouched.
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))
247
248
class SendfileMixin(SendfileBase):
    """Tests for loop.sendfile() over plain and SSL stream transports.

    prepare_sendfile() calls reduce_receive_buffer_size() and
    reduce_send_buffer_size(), which this class does not define.  In this
    file they come from SockSendfileMixin, combined in SendfileTestsBase.
    """

    # Note: sendfile over an SSL transport is equivalent to the sendfile
    # fallback.

    def _force_sendfile_fallback(self):
        """Make self.loop use BaseEventLoop._sendfile_native.

        The base implementation raises SendfileNotAvailableError, so
        loop.sendfile() cannot use a native path: it falls back, or fails
        when called with fallback=False.
        """
        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Create a connected server/client pair of MySendfileProto.

        Returns ``(srv_proto, cli_proto)``.  With *is_ssl* the connection
        uses TLS, and the test is skipped if the ssl module is unavailable.
        *close_after* is forwarded to the server protocol.  A cleanup is
        registered that closes both transports and the server.
        """
        port = socket_helper.find_unused_port()
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        srv_sock.bind((socket_helper.HOST, port))
        server = self.run_loop(self.loop.create_server(
            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = socket_helper.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, pr = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        # Done after connect(): on macOS, connect() resets SO_SNDBUF.
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            # srv_proto.transport stays None if the server side never got
            # connection_made().  Its done future would then never
            # resolve, so only close it and wait on it when it is set.
            srv_connected = srv_proto.transport is not None
            if srv_connected:
                srv_proto.transport.close()
            cli_proto.transport.close()
            if srv_connected:
                self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        self._force_sendfile_fallback()

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()
        self._force_sendfile_fallback()

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        if not (sys.platform == 'win32'
                and isinstance(self.loop, asyncio.ProactorEventLoop)):
            # On Windows, Proactor uses transmitFile, which does not update
            # tell()
            self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                            self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
        self._force_sendfile_fallback()

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            try:
                self.run_loop(
                    self.loop.sendfile(cli_proto.transport, self.file))
            except OSError as e:
                # macOS may raise OSError of EPROTOTYPE when writing to a
                # socket that is in the process of closing down.  Translate
                # it, keeping the original error chained for diagnostics.
                if e.errno == errno.EPROTOTYPE and sys.platform == "darwin":
                    raise ConnectionError from e
                raise

        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.side_effect = lambda: False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.run_loop(
                self.loop.sendfile(transport, None, fallback=False))
534
535
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    """Combined transport-level and socket-level sendfile test suite.

    Concrete subclasses add a TestCase base and implement
    create_event_loop().
    """
538
539
if sys.platform == 'win32':

    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Sendfile tests on a default asyncio.SelectorEventLoop."""

        def create_event_loop(self):
            loop = asyncio.SelectorEventLoop()
            return loop

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
        """Sendfile tests on asyncio.ProactorEventLoop."""

        def create_event_loop(self):
            loop = asyncio.ProactorEventLoop()
            return loop

else:
    import selectors

    # Each selector-specific test class is only defined when the platform
    # provides that selector.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):
            """Sendfile tests on a SelectorEventLoop using KqueueSelector."""

            def create_event_loop(self):
                selector = selectors.KqueueSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):
            """Sendfile tests on a SelectorEventLoop using EpollSelector."""

            def create_event_loop(self):
                selector = selectors.EpollSelector()
                return asyncio.SelectorEventLoop(selector)

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):
            """Sendfile tests on a SelectorEventLoop using PollSelector."""

            def create_event_loop(self):
                selector = selectors.PollSelector()
                return asyncio.SelectorEventLoop(selector)

    # SelectSelector should always exist.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):
        """Sendfile tests on a SelectorEventLoop using SelectSelector."""

        def create_event_loop(self):
            selector = selectors.SelectSelector()
            return asyncio.SelectorEventLoop(selector)


if __name__ == '__main__':
    unittest.main()