1 # Wrapper module for _ssl, providing some additional facilities
2 # implemented in Python. Written by Bill Janssen.
3
4 """This module provides some more Pythonic support for SSL.
5
6 Object types:
7
8 SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10 Exceptions:
11
12 SSLError -- exception raised for I/O errors
13
14 Functions:
15
cert_time_to_seconds -- convert time string used for certificate
                        notBefore and notAfter fields to integer
                        seconds past the Epoch (the time values
                        returned from time.time())
20
21 get_server_certificate (addr, ssl_version, ca_certs, timeout) -- Retrieve the
22 certificate from the server at the specified
23 address and return it as a PEM-encoded string
24
25
26 Integer constants:
27
28 SSL_ERROR_ZERO_RETURN
29 SSL_ERROR_WANT_READ
30 SSL_ERROR_WANT_WRITE
31 SSL_ERROR_WANT_X509_LOOKUP
32 SSL_ERROR_SYSCALL
33 SSL_ERROR_SSL
34 SSL_ERROR_WANT_CONNECT
35
36 SSL_ERROR_EOF
37 SSL_ERROR_INVALID_ERROR_CODE
38
The following group defines certificate requirements that one side is
allowing/requiring from the other side:
41
42 CERT_NONE - no certificates from the other side are required (or will
43 be looked at if provided)
44 CERT_OPTIONAL - certificates are not required, but if provided will be
45 validated, and if validation fails, the connection will
46 also fail
47 CERT_REQUIRED - certificates are required, and will be validated, and
48 if validation fails, the connection will also fail
49
50 The following constants identify various SSL protocol variants:
51
52 PROTOCOL_SSLv2
53 PROTOCOL_SSLv3
54 PROTOCOL_SSLv23
55 PROTOCOL_TLS
56 PROTOCOL_TLS_CLIENT
57 PROTOCOL_TLS_SERVER
58 PROTOCOL_TLSv1
59 PROTOCOL_TLSv1_1
60 PROTOCOL_TLSv1_2
61
62 The following constants identify various SSL alert message descriptions as per
63 http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
64
65 ALERT_DESCRIPTION_CLOSE_NOTIFY
66 ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
67 ALERT_DESCRIPTION_BAD_RECORD_MAC
68 ALERT_DESCRIPTION_RECORD_OVERFLOW
69 ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
70 ALERT_DESCRIPTION_HANDSHAKE_FAILURE
71 ALERT_DESCRIPTION_BAD_CERTIFICATE
72 ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
73 ALERT_DESCRIPTION_CERTIFICATE_REVOKED
74 ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
75 ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
76 ALERT_DESCRIPTION_ILLEGAL_PARAMETER
77 ALERT_DESCRIPTION_UNKNOWN_CA
78 ALERT_DESCRIPTION_ACCESS_DENIED
79 ALERT_DESCRIPTION_DECODE_ERROR
80 ALERT_DESCRIPTION_DECRYPT_ERROR
81 ALERT_DESCRIPTION_PROTOCOL_VERSION
82 ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
83 ALERT_DESCRIPTION_INTERNAL_ERROR
84 ALERT_DESCRIPTION_USER_CANCELLED
85 ALERT_DESCRIPTION_NO_RENEGOTIATION
86 ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
87 ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
88 ALERT_DESCRIPTION_UNRECOGNIZED_NAME
89 ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
90 ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
91 ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
92 """
93
94 import sys
95 import os
96 from collections import namedtuple
97 from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
98 from enum import _simple_enum
99
100 import _ssl # if we can't import it, let the error propagate
101
102 from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
103 from _ssl import _SSLContext, MemoryBIO, SSLSession
104 from _ssl import (
105 SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
106 SSLSyscallError, SSLEOFError, SSLCertVerificationError
107 )
108 from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
109 from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes
110 try:
111 from _ssl import RAND_egd
112 except ImportError:
113 # LibreSSL does not provide RAND_egd
114 pass
115
116
117 from _ssl import (
118 HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
119 HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
120 )
121 from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
122
123 _IntEnum._convert_(
124 '_SSLMethod', __name__,
125 lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
126 source=_ssl)
127
128 _IntFlag._convert_(
129 'Options', __name__,
130 lambda name: name.startswith('OP_'),
131 source=_ssl)
132
133 _IntEnum._convert_(
134 'AlertDescription', __name__,
135 lambda name: name.startswith('ALERT_DESCRIPTION_'),
136 source=_ssl)
137
138 _IntEnum._convert_(
139 'SSLErrorNumber', __name__,
140 lambda name: name.startswith('SSL_ERROR_'),
141 source=_ssl)
142
143 _IntFlag._convert_(
144 'VerifyFlags', __name__,
145 lambda name: name.startswith('VERIFY_'),
146 source=_ssl)
147
148 _IntEnum._convert_(
149 'VerifyMode', __name__,
150 lambda name: name.startswith('CERT_'),
151 source=_ssl)
152
153 PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS
154 _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
155
156 _SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
157
158
159 @_simple_enum(_IntEnum)
160 class ESC[4;38;5;81mTLSVersion:
161 MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
162 SSLv3 = _ssl.PROTO_SSLv3
163 TLSv1 = _ssl.PROTO_TLSv1
164 TLSv1_1 = _ssl.PROTO_TLSv1_1
165 TLSv1_2 = _ssl.PROTO_TLSv1_2
166 TLSv1_3 = _ssl.PROTO_TLSv1_3
167 MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
168
169
170 @_simple_enum(_IntEnum)
171 class ESC[4;38;5;81m_TLSContentType:
172 """Content types (record layer)
173
174 See RFC 8446, section B.1
175 """
176 CHANGE_CIPHER_SPEC = 20
177 ALERT = 21
178 HANDSHAKE = 22
179 APPLICATION_DATA = 23
180 # pseudo content types
181 HEADER = 0x100
182 INNER_CONTENT_TYPE = 0x101
183
184
185 @_simple_enum(_IntEnum)
186 class ESC[4;38;5;81m_TLSAlertType:
187 """Alert types for TLSContentType.ALERT messages
188
189 See RFC 8466, section B.2
190 """
191 CLOSE_NOTIFY = 0
192 UNEXPECTED_MESSAGE = 10
193 BAD_RECORD_MAC = 20
194 DECRYPTION_FAILED = 21
195 RECORD_OVERFLOW = 22
196 DECOMPRESSION_FAILURE = 30
197 HANDSHAKE_FAILURE = 40
198 NO_CERTIFICATE = 41
199 BAD_CERTIFICATE = 42
200 UNSUPPORTED_CERTIFICATE = 43
201 CERTIFICATE_REVOKED = 44
202 CERTIFICATE_EXPIRED = 45
203 CERTIFICATE_UNKNOWN = 46
204 ILLEGAL_PARAMETER = 47
205 UNKNOWN_CA = 48
206 ACCESS_DENIED = 49
207 DECODE_ERROR = 50
208 DECRYPT_ERROR = 51
209 EXPORT_RESTRICTION = 60
210 PROTOCOL_VERSION = 70
211 INSUFFICIENT_SECURITY = 71
212 INTERNAL_ERROR = 80
213 INAPPROPRIATE_FALLBACK = 86
214 USER_CANCELED = 90
215 NO_RENEGOTIATION = 100
216 MISSING_EXTENSION = 109
217 UNSUPPORTED_EXTENSION = 110
218 CERTIFICATE_UNOBTAINABLE = 111
219 UNRECOGNIZED_NAME = 112
220 BAD_CERTIFICATE_STATUS_RESPONSE = 113
221 BAD_CERTIFICATE_HASH_VALUE = 114
222 UNKNOWN_PSK_IDENTITY = 115
223 CERTIFICATE_REQUIRED = 116
224 NO_APPLICATION_PROTOCOL = 120
225
226
227 @_simple_enum(_IntEnum)
228 class ESC[4;38;5;81m_TLSMessageType:
229 """Message types (handshake protocol)
230
231 See RFC 8446, section B.3
232 """
233 HELLO_REQUEST = 0
234 CLIENT_HELLO = 1
235 SERVER_HELLO = 2
236 HELLO_VERIFY_REQUEST = 3
237 NEWSESSION_TICKET = 4
238 END_OF_EARLY_DATA = 5
239 HELLO_RETRY_REQUEST = 6
240 ENCRYPTED_EXTENSIONS = 8
241 CERTIFICATE = 11
242 SERVER_KEY_EXCHANGE = 12
243 CERTIFICATE_REQUEST = 13
244 SERVER_DONE = 14
245 CERTIFICATE_VERIFY = 15
246 CLIENT_KEY_EXCHANGE = 16
247 FINISHED = 20
248 CERTIFICATE_URL = 21
249 CERTIFICATE_STATUS = 22
250 SUPPLEMENTAL_DATA = 23
251 KEY_UPDATE = 24
252 NEXT_PROTO = 67
253 MESSAGE_HASH = 254
254 CHANGE_CIPHER_SPEC = 0x0101
255
256
257 if sys.platform == "win32":
258 from _ssl import enum_certificates, enum_crls
259
260 from socket import socket, SOCK_STREAM, create_connection
261 from socket import SOL_SOCKET, SO_TYPE, _GLOBAL_DEFAULT_TIMEOUT
262 import socket as _socket
263 import base64 # for DER-to-PEM translation
264 import errno
265 import warnings
266
267
268 socket_error = OSError # keep that public name in module namespace
269
270 CHANNEL_BINDING_TYPES = ['tls-unique']
271
272 HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')
273
274
275 _RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS
276
277 CertificateError = SSLCertVerificationError
278
279
280 def _dnsname_match(dn, hostname):
281 """Matching according to RFC 6125, section 6.4.3
282
283 - Hostnames are compared lower-case.
284 - For IDNA, both dn and hostname must be encoded as IDN A-label (ACE).
285 - Partial wildcards like 'www*.example.org', multiple wildcards, sole
286 wildcard or wildcards in labels other then the left-most label are not
287 supported and a CertificateError is raised.
288 - A wildcard must match at least one character.
289 """
290 if not dn:
291 return False
292
293 wildcards = dn.count('*')
294 # speed up common case w/o wildcards
295 if not wildcards:
296 return dn.lower() == hostname.lower()
297
298 if wildcards > 1:
299 raise CertificateError(
300 "too many wildcards in certificate DNS name: {!r}.".format(dn))
301
302 dn_leftmost, sep, dn_remainder = dn.partition('.')
303
304 if '*' in dn_remainder:
305 # Only match wildcard in leftmost segment.
306 raise CertificateError(
307 "wildcard can only be present in the leftmost label: "
308 "{!r}.".format(dn))
309
310 if not sep:
311 # no right side
312 raise CertificateError(
313 "sole wildcard without additional labels are not support: "
314 "{!r}.".format(dn))
315
316 if dn_leftmost != '*':
317 # no partial wildcard matching
318 raise CertificateError(
319 "partial wildcards in leftmost label are not supported: "
320 "{!r}.".format(dn))
321
322 hostname_leftmost, sep, hostname_remainder = hostname.partition('.')
323 if not hostname_leftmost or not sep:
324 # wildcard must match at least one char
325 return False
326 return dn_remainder.lower() == hostname_remainder.lower()
327
328
329 def _inet_paton(ipname):
330 """Try to convert an IP address to packed binary form
331
332 Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6
333 support.
334 """
335 # inet_aton() also accepts strings like '1', '127.1', some also trailing
336 # data like '127.0.0.1 whatever'.
337 try:
338 addr = _socket.inet_aton(ipname)
339 except OSError:
340 # not an IPv4 address
341 pass
342 else:
343 if _socket.inet_ntoa(addr) == ipname:
344 # only accept injective ipnames
345 return addr
346 else:
347 # refuse for short IPv4 notation and additional trailing data
348 raise ValueError(
349 "{!r} is not a quad-dotted IPv4 address.".format(ipname)
350 )
351
352 try:
353 return _socket.inet_pton(_socket.AF_INET6, ipname)
354 except OSError:
355 raise ValueError("{!r} is neither an IPv4 nor an IP6 "
356 "address.".format(ipname))
357 except AttributeError:
358 # AF_INET6 not available
359 pass
360
361 raise ValueError("{!r} is not an IPv4 address.".format(ipname))
362
363
364 def _ipaddress_match(cert_ipaddress, host_ip):
365 """Exact matching of IP addresses.
366
367 RFC 6125 explicitly doesn't define an algorithm for this
368 (section 1.7.2 - "Out of Scope").
369 """
370 # OpenSSL may add a trailing newline to a subjectAltName's IP address,
371 # commonly with IPv6 addresses. Strip off trailing \n.
372 ip = _inet_paton(cert_ipaddress.rstrip())
373 return ip == host_ip
374
375
def match_hostname(cert, hostname):
    """Check that *cert* is valid for *hostname* (deprecated).

    *cert* is a decoded certificate as returned by
    SSLSocket.getpeercert().  RFC 2818 and RFC 6125 rules are followed.
    When *hostname* parses as an IP address, subjectAltName 'IP Address'
    entries are compared instead of 'DNS' entries.  IPv4 works everywhere;
    IPv6 requires AF_INET6 and inet_pton on the platform.

    The subject commonName is only consulted when subjectAltName has no
    'DNS' or 'IP Address' entry.

    Returns None on success; raises CertificateError on mismatch and
    ValueError when *cert* is empty.
    """
    warnings.warn(
        "ssl.match_hostname() is deprecated",
        category=DeprecationWarning,
        stacklevel=2
    )
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")
    try:
        host_ip = _inet_paton(hostname)
    except ValueError:
        host_ip = None  # the common case: a DNS name, not an IP literal

    candidates = []
    for kind, value in cert.get('subjectAltName', ()):
        if kind == 'DNS':
            matched = host_ip is None and _dnsname_match(value, hostname)
        elif kind == 'IP Address':
            matched = host_ip is not None and _ipaddress_match(value, host_ip)
        else:
            continue
        if matched:
            return
        candidates.append(value)

    if not candidates:
        # No usable subjectAltName entry: fall back to the subject's
        # commonName attributes.
        # XXX according to RFC 2818, the most specific Common Name
        # must be used.
        for rdn in cert.get('subject', ()):
            for attr, value in rdn:
                if attr != 'commonName':
                    continue
                if _dnsname_match(value, hostname):
                    return
                candidates.append(value)

    if not candidates:
        raise CertificateError("no appropriate commonName or "
                               "subjectAltName fields were found")
    if len(candidates) == 1:
        raise CertificateError("hostname %r "
                               "doesn't match %r"
                               % (hostname, candidates[0]))
    raise CertificateError("hostname %r "
                           "doesn't match either of %s"
                           % (hostname, ', '.join(map(repr, candidates))))
436
437
# Result type of get_default_verify_paths(): the effective cafile/capath
# followed by the four raw values reported by _ssl.
DefaultVerifyPaths = namedtuple(
    "DefaultVerifyPaths",
    ["cafile", "capath",
     "openssl_cafile_env", "openssl_cafile",
     "openssl_capath_env", "openssl_capath"],
)
441
def get_default_verify_paths():
    """Return a DefaultVerifyPaths with the default CA file and directory.

    The OpenSSL environment variables (whose names _ssl reports) take
    precedence over the compiled-in paths.  ``cafile``/``capath`` are None
    when the chosen path is not an existing file/directory; the remaining
    four fields are the raw values from _ssl.
    """
    raw = _ssl.get_default_verify_paths()
    cafile_env, cafile_default, capath_env, capath_default = raw

    # An environment variable, when set, shadows the built-in default.
    cafile = os.environ.get(cafile_env, cafile_default)
    capath = os.environ.get(capath_env, capath_default)

    return DefaultVerifyPaths(
        cafile if os.path.isfile(cafile) else None,
        capath if os.path.isdir(capath) else None,
        *raw,
    )
454
455
456 class ESC[4;38;5;81m_ASN1Object(ESC[4;38;5;149mnamedtuple("_ASN1Object", "nid shortname longname oid")):
457 """ASN.1 object identifier lookup
458 """
459 __slots__ = ()
460
461 def __new__(cls, oid):
462 return super().__new__(cls, *_txt2obj(oid, name=False))
463
464 @classmethod
465 def fromnid(cls, nid):
466 """Create _ASN1Object from OpenSSL numeric ID
467 """
468 return super().__new__(cls, *_nid2obj(nid))
469
470 @classmethod
471 def fromname(cls, name):
472 """Create _ASN1Object from short name, long name or OID
473 """
474 return super().__new__(cls, *_txt2obj(name, name=True))
475
476
class Purpose(_ASN1Object, _Enum):
    """Intended use of an SSLContext, named by X509v3 Extended Key Usage OID.

    SERVER_AUTH builds client-side contexts (the peer is a server) and
    CLIENT_AUTH builds server-side contexts; see create_default_context().
    """
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
482
483
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""
    # Windows certificate stores read by load_default_certs() on win32.
    _windows_cert_stores = ("CA", "ROOT")

    sslsocket_class = None  # SSLSocket is assigned later.
    sslobject_class = None  # SSLObject is assigned later.

    def __new__(cls, protocol=None, *args, **kwargs):
        """Create a context for *protocol*.

        Omitting *protocol* is deprecated: a DeprecationWarning is emitted
        and PROTOCOL_TLS is used instead.
        """
        if protocol is None:
            warnings.warn(
                "ssl.SSLContext() without protocol argument is deprecated.",
                category=DeprecationWarning,
                stacklevel=2
            )
            protocol = PROTOCOL_TLS
        self = _SSLContext.__new__(cls, protocol)
        return self

    def _encode_hostname(self, hostname):
        """Normalize *hostname* to an ASCII str.

        None passes through; a str is IDNA-encoded (A-label form); anything
        else is treated as bytes and decoded as ASCII.
        """
        if hostname is None:
            return None
        elif isinstance(hostname, str):
            return hostname.encode('idna').decode('ascii')
        else:
            return hostname.decode('ascii')

    @staticmethod
    def _encode_protocol_list(protocols, label):
        """Serialize protocol names for NPN/ALPN.

        Each name is ASCII-encoded and preceded by a single length byte, so
        it must be 1 to 255 bytes long; otherwise SSLError is raised with a
        message prefixed by *label*.  Returns the encoded bytearray.
        """
        protos = bytearray()
        for protocol in protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError(f'{label} protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)
        return protos

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Return a ``sslsocket_class`` instance wrapping *sock* with this
        context."""
        # SSLSocket class handles server_hostname encoding before it calls
        # ctx._wrap_socket()
        return self.sslsocket_class._create(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            context=self,
            session=session
        )

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Return a ``sslobject_class`` instance operating on the memory
        BIOs *incoming* and *outgoing*."""
        # Need to encode server_hostname here because _wrap_bio() can only
        # handle ASCII str.
        return self.sslobject_class._create(
            incoming, outgoing, server_side=server_side,
            server_hostname=self._encode_hostname(server_hostname),
            session=session, context=self,
        )

    def set_npn_protocols(self, npn_protocols):
        """Advertise *npn_protocols* via NPN (deprecated; use ALPN)."""
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        self._set_npn_protocols(
            self._encode_protocol_list(npn_protocols, 'NPN'))

    def set_servername_callback(self, server_name_callback):
        """Install a legacy SNI callback, or remove it when given None.

        The server name is normalized with _encode_hostname() before it is
        passed to *server_name_callback*.  Raises TypeError if the callback
        is not callable.
        """
        if server_name_callback is None:
            self.sni_callback = None
        else:
            if not callable(server_name_callback):
                raise TypeError("not a callable object")

            def shim_cb(sslobj, servername, sslctx):
                servername = self._encode_hostname(servername)
                return server_name_callback(sslobj, servername, sslctx)

            self.sni_callback = shim_cb

    def set_alpn_protocols(self, alpn_protocols):
        """Advertise *alpn_protocols* (an iterable of str) via ALPN."""
        self._set_alpn_protocols(
            self._encode_protocol_list(alpn_protocols, 'ALPN'))

    def _load_windows_store_certs(self, storename, purpose):
        """Load trusted certificates from the Windows store *storename*.

        Only "x509_asn" entries are used, and only when they are trusted for
        everything (``trust is True``) or for *purpose*'s OID.  Returns the
        concatenated certificate bytes that were loaded.
        """
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except PermissionError:
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load system CA certificates for *purpose*.

        On win32 the Windows stores are read first; OpenSSL's default
        verify paths are always set.  Raises TypeError if *purpose* is not
        an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    if hasattr(_SSLContext, 'minimum_version'):
        @property
        def minimum_version(self):
            """Lowest accepted protocol version, as a TLSVersion."""
            return TLSVersion(super().minimum_version)

        @minimum_version.setter
        def minimum_version(self, value):
            # Allowing SSLv3 as the minimum also requires lifting
            # OP_NO_SSLv3, otherwise the option would still block it.
            if value == TLSVersion.SSLv3:
                self.options &= ~Options.OP_NO_SSLv3
            super(SSLContext, SSLContext).minimum_version.__set__(self, value)

        @property
        def maximum_version(self):
            """Highest accepted protocol version, as a TLSVersion."""
            return TLSVersion(super().maximum_version)

        @maximum_version.setter
        def maximum_version(self, value):
            super(SSLContext, SSLContext).maximum_version.__set__(self, value)

    @property
    def options(self):
        """OpenSSL option flags, as an Options flag value."""
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'):
        @property
        def hostname_checks_common_name(self):
            """Whether hostname checks may fall back to the commonName."""
            ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT

        @hostname_checks_common_name.setter
        def hostname_checks_common_name(self, value):
            if value:
                self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            else:
                self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
    else:
        @property
        def hostname_checks_common_name(self):
            """Always True: this OpenSSL cannot disable the fallback."""
            return True

    @property
    def _msg_callback(self):
        """TLS message callback

        The message callback provides a debugging hook to analyze TLS
        connections. The callback is called for any TLS protocol message
        (header, handshake, alert, and more), but not for application data.
        Due to technical limitations, the callback can't be used to filter
        traffic or to abort a connection. Any exception raised in the
        callback is delayed until the handshake, read, or write operation
        has been performed.

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        conn
            :class:`SSLSocket` or :class:`SSLObject` instance
        direction
            ``read`` or ``write``
        version
            :class:`TLSVersion` enum member or int for unknown version. For a
            frame header, it's the header version.
        content_type
            :class:`_TLSContentType` enum member or int for unsupported
            content type.
        msg_type
            Either a :class:`_TLSContentType` enum number for a header
            message, a :class:`_TLSAlertType` enum member for an alert
            message, a :class:`_TLSMessageType` enum member for other
            messages, or int for unsupported message types.
        data
            Raw, decrypted message content as bytes
        """
        inner = super()._msg_callback
        if inner is not None:
            # The C layer holds the enum-converting wrapper; hand back the
            # user's original function.
            return inner.user_function
        else:
            return None

    @_msg_callback.setter
    def _msg_callback(self, callback):
        if callback is None:
            super(SSLContext, SSLContext)._msg_callback.__set__(self, None)
            return

        # callable() matches the check used by set_servername_callback().
        if not callable(callback):
            raise TypeError(f"{callback} is not callable.")

        def inner(conn, direction, version, content_type, msg_type, data):
            # Convert raw integers into enum members where known; unknown
            # values are passed through unchanged.
            try:
                version = TLSVersion(version)
            except ValueError:
                pass

            try:
                content_type = _TLSContentType(content_type)
            except ValueError:
                pass

            if content_type == _TLSContentType.HEADER:
                msg_enum = _TLSContentType
            elif content_type == _TLSContentType.ALERT:
                msg_enum = _TLSAlertType
            else:
                msg_enum = _TLSMessageType
            try:
                msg_type = msg_enum(msg_type)
            except ValueError:
                pass

            return callback(conn, direction, version,
                            content_type, msg_type, data)

        inner.user_function = callback

        super(SSLContext, SSLContext)._msg_callback.__set__(self, inner)

    @property
    def protocol(self):
        """The protocol this context was created with, as _SSLMethod."""
        return _SSLMethod(super().protocol)

    @property
    def verify_flags(self):
        """Certificate verification flags, as a VerifyFlags value."""
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        """Peer verification mode as VerifyMode (raw int if unknown)."""
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
743
744
745 def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
746 capath=None, cadata=None):
747 """Create a SSLContext object with default settings.
748
749 NOTE: The protocol and settings may change anytime without prior
750 deprecation. The values represent a fair balance between maximum
751 compatibility and security.
752 """
753 if not isinstance(purpose, _ASN1Object):
754 raise TypeError(purpose)
755
756 # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
757 # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
758 # by default.
759 if purpose == Purpose.SERVER_AUTH:
760 # verify certs and host name in client mode
761 context = SSLContext(PROTOCOL_TLS_CLIENT)
762 context.verify_mode = CERT_REQUIRED
763 context.check_hostname = True
764 elif purpose == Purpose.CLIENT_AUTH:
765 context = SSLContext(PROTOCOL_TLS_SERVER)
766 else:
767 raise ValueError(purpose)
768
769 if cafile or capath or cadata:
770 context.load_verify_locations(cafile, capath, cadata)
771 elif context.verify_mode != CERT_NONE:
772 # no explicit cafile, capath or cadata but the verify mode is
773 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
774 # root CA certificates for the given purpose. This may fail silently.
775 context.load_default_certs(purpose)
776 # OpenSSL 1.1.1 keylog file
777 if hasattr(context, 'keylog_filename'):
778 keylogfile = os.environ.get('SSLKEYLOGFILE')
779 if keylogfile and not sys.flags.ignore_environment:
780 context.keylog_filename = keylogfile
781 return context
782
783 def _create_unverified_context(protocol=None, *, cert_reqs=CERT_NONE,
784 check_hostname=False, purpose=Purpose.SERVER_AUTH,
785 certfile=None, keyfile=None,
786 cafile=None, capath=None, cadata=None):
787 """Create a SSLContext object for Python stdlib modules
788
789 All Python stdlib modules shall use this function to create SSLContext
790 objects in order to keep common settings in one place. The configuration
791 is less restrict than create_default_context()'s to increase backward
792 compatibility.
793 """
794 if not isinstance(purpose, _ASN1Object):
795 raise TypeError(purpose)
796
797 # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
798 # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
799 # by default.
800 if purpose == Purpose.SERVER_AUTH:
801 # verify certs and host name in client mode
802 if protocol is None:
803 protocol = PROTOCOL_TLS_CLIENT
804 elif purpose == Purpose.CLIENT_AUTH:
805 if protocol is None:
806 protocol = PROTOCOL_TLS_SERVER
807 else:
808 raise ValueError(purpose)
809
810 context = SSLContext(protocol)
811 context.check_hostname = check_hostname
812 if cert_reqs is not None:
813 context.verify_mode = cert_reqs
814 if check_hostname:
815 context.check_hostname = True
816
817 if keyfile and not certfile:
818 raise ValueError("certfile must be specified")
819 if certfile or keyfile:
820 context.load_cert_chain(certfile, keyfile)
821
822 # load CA root certs
823 if cafile or capath or cadata:
824 context.load_verify_locations(cafile, capath, cadata)
825 elif context.verify_mode != CERT_NONE:
826 # no explicit cafile, capath or cadata but the verify mode is
827 # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
828 # root CA certificates for the given purpose. This may fail silently.
829 context.load_default_certs(purpose)
830 # OpenSSL 1.1.1 keylog file
831 if hasattr(context, 'keylog_filename'):
832 keylogfile = os.environ.get('SSLKEYLOGFILE')
833 if keylogfile and not sys.flags.ignore_environment:
834 context.keylog_filename = keylogfile
835 return context
836
837 # Used by http.client if no context is explicitly passed.
838 _create_default_https_context = create_default_context
839
840
841 # Backwards compatibility alias, even though it's not a public name.
842 _create_stdlib_context = _create_unverified_context
843
844
class SSLObject:
    """State of one SSL connection, without any network IO of its own.

    Data is exchanged through separate "BIO" objects (OpenSSL's IO
    abstraction layer), which makes this class suitable for framework
    authors implementing asynchronous IO over memory buffers.

    There is no public constructor; instances come from
    ``SSLContext.wrap_bio``.

    Compared with ``SSLSocket`` this object lacks:

    * Any form of network IO, including methods such as ``recv`` and ``send``.
    * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
    """
    def __init__(self, *args, **kwargs):
        cls_name = self.__class__.__name__
        raise TypeError(
            f"{cls_name} does not have a public constructor. "
            f"Instances are returned by SSLContext.wrap_bio()."
        )

    @classmethod
    def _create(cls, incoming, outgoing, server_side=False,
                server_hostname=None, session=None, context=None):
        # Bypass __init__ (which always raises) and attach the low-level
        # object produced by the context.
        self = cls.__new__(cls)
        self._sslobj = context._wrap_bio(
            incoming, outgoing, server_side=server_side,
            server_hostname=server_hostname,
            owner=self, session=session
        )
        return self

    @property
    def context(self):
        """The SSLContext that is currently in use."""
        return self._sslobj.context

    @context.setter
    def context(self, ctx):
        self._sslobj.context = ctx

    @property
    def session(self):
        """The SSLSession for client socket."""
        return self._sslobj.session

    @session.setter
    def session(self, session):
        self._sslobj.session = session

    @property
    def session_reused(self):
        """Was the client session reused during handshake"""
        return self._sslobj.session_reused

    @property
    def server_side(self):
        """Whether this is a server-side socket."""
        return self._sslobj.server_side

    @property
    def server_hostname(self):
        """The currently set server hostname (for SNI), or ``None`` if no
        server hostname is set."""
        return self._sslobj.server_hostname

    def read(self, len=1024, buffer=None):
        """Read up to 'len' bytes from the SSL object and return them.

        If 'buffer' is provided, read into this buffer and return the number of
        bytes read.
        """
        if buffer is None:
            return self._sslobj.read(len)
        return self._sslobj.read(len, buffer)

    def write(self, data):
        """Write 'data' to the SSL object and return the number of bytes
        written.

        The 'data' argument must support the buffer interface.
        """
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the certificate provided
        by the other end of the SSL channel.

        Return None if no certificate was provided, {} if a certificate was
        provided, but not validated.
        """
        return self._sslobj.getpeercert(binary_form)

    def selected_npn_protocol(self):
        """Return the currently selected NPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if NPN is not supported by one
        of the peers."""
        # NPN is deprecated; this implementation only warns and always
        # reports that nothing was negotiated.
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        return None

    def selected_alpn_protocol(self):
        """Return the currently selected ALPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if ALPN is not supported by one
        of the peers."""
        return self._sslobj.selected_alpn_protocol()

    def cipher(self):
        """Return the currently selected cipher as a 3-tuple ``(name,
        ssl_version, secret_bits)``."""
        return self._sslobj.cipher()

    def shared_ciphers(self):
        """Return a list of ciphers shared by the client during the handshake or
        None if this is not a valid server connection.
        """
        return self._sslobj.shared_ciphers()

    def compression(self):
        """Return the current compression algorithm in use, or ``None`` if
        compression was not negotiated or not supported by one of the peers."""
        return self._sslobj.compression()

    def pending(self):
        """Return the number of bytes that can be read immediately."""
        return self._sslobj.pending()

    def do_handshake(self):
        """Start the SSL/TLS handshake."""
        self._sslobj.do_handshake()

    def unwrap(self):
        """Start the SSL shutdown handshake."""
        return self._sslobj.shutdown()

    def get_channel_binding(self, cb_type="tls-unique"):
        """Get channel binding data for current connection. Raise ValueError
        if the requested `cb_type` is not supported. Return bytes of the data
        or None if the data is not available (e.g. before the handshake)."""
        return self._sslobj.get_channel_binding(cb_type)

    def version(self):
        """Return a string identifying the protocol version used by the
        current SSL channel. """
        return self._sslobj.version()

    def verify_client_post_handshake(self):
        """Ask the underlying SSL object to run TLS 1.3 post-handshake
        client authentication and return its result."""
        return self._sslobj.verify_client_post_handshake()
998
999
def _sslcopydoc(func):
    """Decorator: give *func* the docstring of the same-named SSLObject member.

    Lets SSLSocket methods and properties share documentation with
    SSLObject.  *func* itself is returned (no wrapper is created).
    """
    counterpart = getattr(SSLObject, func.__name__)
    func.__doc__ = counterpart.__doc__
    return func
1004
1005
class SSLSocket(socket):
    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel.

    The public constructor is disabled; instances are returned by
    SSLContext.wrap_socket(), and the ``_create()`` classmethod performs
    the actual construction.  Whenever ``self._sslobj`` is ``None`` (not
    yet connected, or after unwrap()/shutdown()/close) the data methods
    fall back to plain ``socket`` behaviour."""

    def __init__(self, *args, **kwargs):
        # Direct construction is forbidden; see _create().
        raise TypeError(
            f"{self.__class__.__name__} does not have a public "
            "constructor. Instances are returned by "
            "SSLContext.wrap_socket()."
        )

    @classmethod
    def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
                suppress_ragged_eofs=True, server_hostname=None,
                context=None, session=None):
        """Build an SSLSocket that takes over the file descriptor of *sock*.

        *sock* is detached and must not be used afterwards.  If the
        descriptor is already connected, the SSL object is created right
        away and, when *do_handshake_on_connect* is true, the handshake is
        performed (ValueError for non-blocking sockets).  On failure during
        that step the new socket is closed and the error re-raised.

        Raises NotImplementedError for non-stream sockets and ValueError
        for invalid client/server argument combinations.
        """
        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
            raise NotImplementedError("only stream sockets are supported")
        if server_side:
            if server_hostname:
                raise ValueError("server_hostname can only be specified "
                                 "in client mode")
            if session is not None:
                raise ValueError("session can only be specified in "
                                 "client mode")
        if context.check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")

        kwargs = dict(
            family=sock.family, type=sock.type, proto=sock.proto,
            fileno=sock.fileno()
        )
        self = cls.__new__(cls, **kwargs)
        super(SSLSocket, self).__init__(**kwargs)
        # The new object now owns the descriptor; detach the original so
        # it will not close it.
        sock_timeout = sock.gettimeout()
        sock.detach()

        self._context = context
        self._session = session
        self._closed = False
        self._sslobj = None
        self.server_side = server_side
        self.server_hostname = context._encode_hostname(server_hostname)
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

        # See if we are connected
        try:
            self.getpeername()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise
            connected = False
            blocking = self.getblocking()
            self.setblocking(False)
            try:
                # We are not connected so this is not supposed to block, but
                # testing revealed otherwise on macOS and Windows so we do
                # the non-blocking dance regardless. Our raise when any data
                # is found means consuming the data is harmless.
                # (self._sslobj is still None, so this is a raw recv.)
                notconn_pre_handshake_data = self.recv(1)
            except OSError as e:
                # EINVAL occurs for recv(1) on non-connected on unix sockets.
                if e.errno not in (errno.ENOTCONN, errno.EINVAL):
                    raise
                notconn_pre_handshake_data = b''
            self.setblocking(blocking)
            if notconn_pre_handshake_data:
                # This prevents pending data sent to the socket before it was
                # closed from escaping to the caller who could otherwise
                # presume it came through a successful TLS connection.
                # NOTE: ``e`` here is still the outer ENOTCONN error from
                # getpeername(); the inner handler (which unbinds ``e`` on
                # exit) only runs on the path where no data was read.
                reason = "Closed before TLS handshake with data in recv buffer."
                notconn_pre_handshake_data_error = SSLError(e.errno, reason)
                # Add the SSLError attributes that _ssl.c always adds.
                notconn_pre_handshake_data_error.reason = reason
                notconn_pre_handshake_data_error.library = None
                try:
                    self.close()
                except OSError:
                    pass
                try:
                    raise notconn_pre_handshake_data_error
                finally:
                    # Explicitly break the reference cycle.
                    notconn_pre_handshake_data_error = None
        else:
            connected = True

        self.settimeout(sock_timeout)  # Must come after setblocking() calls.
        self._connected = connected
        if connected:
            # create the SSL object
            try:
                self._sslobj = self._context._wrap_socket(
                    self, server_side, self.server_hostname,
                    owner=self, session=self._session,
                )
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()
            except (OSError, ValueError):
                self.close()
                raise
        return self

    @property
    @_sslcopydoc
    def context(self):
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx
        # Before connect() (or after unwrap/shutdown/close) there is no SSL
        # object; _real_connect() builds it from self.context, so only
        # forward the change when one exists.  Previously this raised
        # AttributeError on None.
        if self._sslobj is not None:
            self._sslobj.context = ctx

    @property
    @_sslcopydoc
    def session(self):
        # None until the SSL object exists.
        if self._sslobj is not None:
            return self._sslobj.session
        return None

    @session.setter
    def session(self, session):
        # Remembered so that a later _wrap_socket() call can use it.
        self._session = session
        if self._sslobj is not None:
            self._sslobj.session = session

    @property
    @_sslcopydoc
    def session_reused(self):
        # None until the SSL object exists.
        if self._sslobj is not None:
            return self._sslobj.session_reused
        return None

    def dup(self):
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def _check_connected(self):
        if not self._connected:
            # getpeername() will raise ENOTCONN if the socket is really
            # not connected; note that we can be connected even without
            # _connected being set, e.g. if connect() first returned
            # EAGAIN.
            self.getpeername()

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        When *buffer* is given, data is read into it and the byte count is
        returned instead.  A ragged EOF (SSL_ERROR_EOF) is reported as an
        empty read when suppress_ragged_eofs is set; raises ValueError if
        there is no SSL layer."""

        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Read on closed or unwrapped SSL socket.")
        try:
            if buffer is not None:
                return self._sslobj.read(len, buffer)
            else:
                return self._sslobj.read(len)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                if buffer is not None:
                    return 0
                else:
                    return b''
            else:
                raise

    def write(self, data):
        """Write DATA to the underlying SSL channel. Returns
        number of bytes of DATA actually transmitted.

        Raises ValueError if there is no SSL layer."""

        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Write on closed or unwrapped SSL socket.")
        return self._sslobj.write(data)

    @_sslcopydoc
    def getpeercert(self, binary_form=False):
        self._checkClosed()
        self._check_connected()
        return self._sslobj.getpeercert(binary_form)

    @_sslcopydoc
    def selected_npn_protocol(self):
        # NPN support is gone; always warn and report no protocol.
        self._checkClosed()
        warnings.warn(
            "ssl NPN is deprecated, use ALPN instead",
            DeprecationWarning,
            stacklevel=2
        )
        return None

    @_sslcopydoc
    def selected_alpn_protocol(self):
        self._checkClosed()
        if self._sslobj is None or not _ssl.HAS_ALPN:
            return None
        else:
            return self._sslobj.selected_alpn_protocol()

    @_sslcopydoc
    def cipher(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.cipher()

    @_sslcopydoc
    def shared_ciphers(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.shared_ciphers()

    @_sslcopydoc
    def compression(self):
        self._checkClosed()
        if self._sslobj is None:
            return None
        else:
            return self._sslobj.compression()

    def send(self, data, flags=0):
        # Encrypted path when an SSL object exists; plain socket otherwise.
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            return self._sslobj.write(data)
        else:
            return super().send(data, flags)

    def sendto(self, data, flags_or_addr, addr=None):
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        elif addr is None:
            return super().sendto(data, flags_or_addr)
        else:
            return super().sendto(data, flags_or_addr, addr)

    def sendmsg(self, *args, **kwargs):
        # Ensure programs don't send data unencrypted if they try to
        # use this method.
        raise NotImplementedError("sendmsg not allowed on instances of %s" %
                                  self.__class__)

    def sendall(self, data, flags=0):
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            # Loop over a flat byte view until every byte has been written.
            count = 0
            with memoryview(data) as view, view.cast("B") as byte_view:
                amount = len(byte_view)
                while count < amount:
                    v = self.send(byte_view[count:])
                    count += v
        else:
            return super().sendall(data, flags)

    def sendfile(self, file, offset=0, count=None):
        """Send a file, possibly by using os.sendfile() if this is a
        clear-text socket. Return the total number of bytes sent.
        """
        if self._sslobj is not None:
            return self._sendfile_use_send(file, offset, count)
        else:
            # os.sendfile() works with plain sockets only
            return super().sendfile(file, offset, count)

    def recv(self, buflen=1024, flags=0):
        self._checkClosed()
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self.read(buflen)
        else:
            return super().recv(buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        self._checkClosed()
        # Default nbytes to the buffer's size, or 1024 when that is
        # unknown or zero.
        if nbytes is None:
            if buffer is not None:
                with memoryview(buffer) as view:
                    nbytes = view.nbytes
                if not nbytes:
                    nbytes = 1024
            else:
                nbytes = 1024
        if self._sslobj is not None:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            return self.read(nbytes, buffer)
        else:
            return super().recv_into(buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        else:
            return super().recvfrom(buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        self._checkClosed()
        if self._sslobj is not None:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        else:
            return super().recvfrom_into(buffer, nbytes, flags)

    def recvmsg(self, *args, **kwargs):
        raise NotImplementedError("recvmsg not allowed on instances of %s" %
                                  self.__class__)

    def recvmsg_into(self, *args, **kwargs):
        raise NotImplementedError("recvmsg_into not allowed on instances of "
                                  "%s" % self.__class__)

    @_sslcopydoc
    def pending(self):
        self._checkClosed()
        if self._sslobj is not None:
            return self._sslobj.pending()
        else:
            return 0

    def shutdown(self, how):
        # Drops the SSL layer without running the SSL shutdown handshake
        # (compare unwrap()), then shuts down the OS socket.
        self._checkClosed()
        self._sslobj = None
        super().shutdown(how)

    @_sslcopydoc
    def unwrap(self):
        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    @_sslcopydoc
    def verify_client_post_handshake(self):
        if self._sslobj:
            return self._sslobj.verify_client_post_handshake()
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    def _real_close(self):
        self._sslobj = None
        super()._real_close()

    @_sslcopydoc
    def do_handshake(self, block=False):
        self._check_connected()
        timeout = self.gettimeout()
        try:
            # block=True temporarily makes a non-blocking socket blocking
            # for the handshake; the original timeout is always restored.
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def _real_connect(self, addr, connect_ex):
        # Shared implementation of connect() / connect_ex().  With
        # connect_ex, a non-zero rc is returned rather than raised and the
        # socket is not marked connected (the SSL object is kept).
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call. We connect it, then wrap it.
        if self._connected or self._sslobj is not None:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._sslobj = self.context._wrap_socket(
            self, False, self.server_hostname,
            owner=self, session=self._session
        )
        try:
            if connect_ex:
                rc = super().connect_ex(addr)
            else:
                rc = None
                super().connect(addr)
            if not rc:
                self._connected = True
                if self.do_handshake_on_connect:
                    self.do_handshake()
            return rc
        except (OSError, ValueError):
            self._sslobj = None
            raise

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        self._real_connect(addr, False)

    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        return self._real_connect(addr, True)

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = super().accept()
        newsock = self.context.wrap_socket(newsock,
                    do_handshake_on_connect=self.do_handshake_on_connect,
                    suppress_ragged_eofs=self.suppress_ragged_eofs,
                    server_side=True)
        return newsock, addr

    @_sslcopydoc
    def get_channel_binding(self, cb_type="tls-unique"):
        if self._sslobj is not None:
            return self._sslobj.get_channel_binding(cb_type)
        else:
            # No SSL layer: still reject unknown binding types, but no
            # data is available.
            if cb_type not in CHANNEL_BINDING_TYPES:
                raise ValueError(
                    "{0} channel binding type not implemented".format(cb_type)
                )
            return None

    @_sslcopydoc
    def version(self):
        if self._sslobj is not None:
            return self._sslobj.version()
        else:
            return None
1452
1453
# SSLContext needs the wrapper classes, but they are defined after it and
# Python has no forward declarations, so link them in once both exist.
SSLContext.sslsocket_class, SSLContext.sslobject_class = SSLSocket, SSLObject
1457
1458
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Deprecated: wrap *sock* using a one-off SSLContext.

    Builds an SSLContext for *ssl_version* with verify mode *cert_reqs*,
    optionally loads *ca_certs*, the *certfile*/*keyfile* chain and the
    *ciphers* string, then returns ``context.wrap_socket(...)``.

    Raises ValueError when a certificate file is required (server side,
    or a key file given) but *certfile* is missing.  Always emits a
    DeprecationWarning.
    """
    warnings.warn(
        "ssl.wrap_socket() is deprecated, use SSLContext.wrap_socket()",
        category=DeprecationWarning,
        stacklevel=2
    )
    if not certfile:
        if server_side:
            raise ValueError("certfile must be specified for server-side "
                             "operations")
        if keyfile:
            raise ValueError("certfile must be specified")

    ctx = SSLContext(ssl_version)
    ctx.verify_mode = cert_reqs
    if ca_certs:
        ctx.load_verify_locations(ca_certs)
    if certfile:
        ctx.load_cert_chain(certfile, keyfile)
    if ciphers:
        ctx.set_ciphers(ciphers)

    wrap_options = {
        "server_side": server_side,
        "do_handshake_on_connect": do_handshake_on_connect,
        "suppress_ragged_eofs": suppress_ragged_eofs,
    }
    return ctx.wrap_socket(sock=sock, **wrap_options)
1488
1489 # some utility functions
1490
def cert_time_to_seconds(cert_time):
    """Return the time in seconds since the Epoch, given the timestring
    representing the "notBefore" or "notAfter" date from a certificate
    in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).

    "notBefore" or "notAfter" dates must use UTC (RFC 5280).

    Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
    UTC should be specified as GMT (see ASN1_TIME_print())

    The result is an int.  Raises ValueError if *cert_time* does not
    match the expected format.
    """
    from time import strptime
    from calendar import timegm

    months = (
        "Jan","Feb","Mar","Apr","May","Jun",
        "Jul","Aug","Sep","Oct","Nov","Dec"
    )
    time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
    try:
        # The month is matched by hand rather than with %b (presumably
        # because %b depends on the current locale while certificates use
        # the English abbreviations above).
        month_number = months.index(cert_time[:3].title()) + 1
    except ValueError:
        # Hide the internal "tuple.index(x): x not in tuple" context; the
        # caller only needs to know the input was malformed.
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, time_format)) from None
    else:
        # found valid month
        tt = strptime(cert_time[3:], time_format)
        # return an integer, the previous mktime()-based implementation
        # returned a float (fractional seconds are always zero here).
        return timegm((tt[0], month_number) + tt[2:6])
1520
# Armor lines that delimit a PEM-encoded certificate.
PEM_HEADER, PEM_FOOTER = (
    "-----BEGIN CERTIFICATE-----",
    "-----END CERTIFICATE-----",
)
1523
def DER_cert_to_PEM_cert(der_cert_bytes):
    """Convert a binary DER certificate to PEM text.

    The base64 body is split into 64-character lines between PEM_HEADER
    and PEM_FOOTER, and the result ends with a newline."""
    encoded = base64.standard_b64encode(der_cert_bytes).decode('ASCII', 'strict')
    body = [encoded[start:start + 64] for start in range(0, len(encoded), 64)]
    return '\n'.join([PEM_HEADER, *body, PEM_FOOTER + '\n'])
1533
def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Whitespace surrounding the PEM block is ignored.  Raises ValueError
    if the text does not start with PEM_HEADER or end with PEM_FOOTER."""

    # Validate exactly the text that is decoded below.  Previously the
    # header was checked on the unstripped input while the footer and the
    # payload used the stripped text, so leading whitespace was rejected
    # but trailing whitespace was accepted.
    pem = pem_cert_string.strip()
    if not pem.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
1546
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS_CLIENT,
                           ca_certs=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
    """Connect to the server at *addr* and return its certificate as PEM text.

    *addr* is a ``(host, port)`` pair; *host* is also passed as the
    server hostname for the TLS connection.  When *ca_certs* is given,
    the server certificate is required and validated against it;
    otherwise no verification is requested.  *ssl_version* selects the
    protocol and *timeout* is used for the connection attempt.
    """
    host, _port = addr
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr, timeout=timeout) as raw_sock, \
            context.wrap_socket(raw_sock, server_hostname=host) as tls_sock:
        der_bytes = tls_sock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_bytes)
1568
def get_protocol_name(protocol_code):
    """Return the PROTOCOL_* constant name for *protocol_code*, or
    ``'<unknown>'`` if it is not a known protocol."""
    try:
        return _PROTOCOL_NAMES[protocol_code]
    except KeyError:
        return '<unknown>'