diff --git a/eventlib/green/socket.py b/eventlib/green/socket.py index 60d3248..7a4623b 100644 --- a/eventlib/green/socket.py +++ b/eventlib/green/socket.py @@ -1,110 +1,109 @@ import socket as __socket -from socket import (__all__, error, AF_INET, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR, getdefaulttimeout, gethostname, getnameinfo, getservbyname, herror, htonl, SOCK_DGRAM, timeout, gaierror, SOCK_RAW, setdefaulttimeout, getservbyport, gethostbyaddr, ntohl, htons, ntohs, getfqdn, SOCK_RDM, SOCK_SEQPACKET) +from socket import (__all__, error, AF_INET, AF_UNIX, SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR, getdefaulttimeout, gethostname, getnameinfo, getservbyname, herror, htonl, SOCK_DGRAM, timeout, gaierror, SOCK_RAW, setdefaulttimeout, getservbyport, gethostbyaddr, ntohl, htons, ntohs, getfqdn, SOCK_RDM, SOCK_SEQPACKET) from eventlib.api import get_hub from eventlib.greenio import GreenSocket as socket from eventlib.greenio import socketpair, fromfd from socket import SocketIO +from ssl import SSLError as sslerror import warnings def gethostbyname(name): if getattr(get_hub(), 'uses_twisted_reactor', None): globals()['gethostbyname'] = _gethostbyname_twisted else: globals()['gethostbyname'] = _gethostbyname_tpool return globals()['gethostbyname'](name) def _gethostbyname_twisted(name): from twisted.internet import reactor from eventlib.twistedutil import block_on as _block_on return _block_on(reactor.resolve(name)) def _gethostbyname_tpool(name): from eventlib import tpool return tpool.execute( __socket.gethostbyname, name) def getaddrinfo(*args, **kw): if getattr(get_hub(), 'uses_twisted_reactor', None): globals()['getaddrinfo'] = _getaddrinfo_twisted else: globals()['getaddrinfo'] = _getaddrinfo_tpool return globals()['getaddrinfo'](*args, **kw) def _getaddrinfo_twisted(*args, **kw): from twisted.internet.threads import deferToThread from eventlib.twistedutil import block_on as _block_on return _block_on(deferToThread(__socket.getaddrinfo, *args, **kw)) def _getaddrinfo_tpool(*args, **kw): from eventlib import tpool return tpool.execute( __socket.getaddrinfo, *args, **kw) # XXX there're few more blocking functions in socket # XXX having a hub-independent way to access thread pool would be nice _GLOBAL_DEFAULT_TIMEOUT = __socket._GLOBAL_DEFAULT_TIMEOUT def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): """Connect to *address* and return the socket object. Convenience function. Connect to *address* (a 2-tuple ``(host, port)``) and return the socket object. Passing the optional *timeout* parameter will set the timeout on the socket instance before attempting to connect. If no *timeout* is supplied, the global default timeout setting returned by :func:`getdefaulttimeout` is used. """ msg = "getaddrinfo returns an empty list" host, port = address for res in getaddrinfo(host, port, 0, SOCK_STREAM): af, socktype, proto, canonname, sa = res sock = None try: sock = socket(af, socktype, proto) if timeout is not _GLOBAL_DEFAULT_TIMEOUT: sock.settimeout(timeout) sock.connect(sa) return sock except error as msg: if sock is not None: sock.close() raise error(msg) try: from eventlib.green import ssl as ssl_module except ImportError: # no ssl support pass else: # some constants the SSL module exports but not in __all__ from eventlib.green.ssl import (RAND_add, RAND_status, SSL_ERROR_ZERO_RETURN, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_X509_LOOKUP, SSL_ERROR_SYSCALL, SSL_ERROR_SSL, SSL_ERROR_WANT_CONNECT, SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE) try: - sslerror = __socket.sslerror - __socket.ssl def ssl(sock, certificate=None, private_key=None): warnings.warn("socket.ssl() is deprecated. Use ssl.wrap_socket() instead.", DeprecationWarning, stacklevel=2) return ssl_module.sslwrap_simple(sock, private_key, certificate) - except AttributeError: + except AttributeError as e: # if the real socket module doesn't have the ssl method or sslerror # exception, we don't emulate them pass diff --git a/eventlib/green/ssl.py b/eventlib/green/ssl.py index 8accc08..4372c3f 100644 --- a/eventlib/green/ssl.py +++ b/eventlib/green/ssl.py @@ -1,261 +1,260 @@ import ssl as __ssl from ssl import _ssl -from _ssl import ( RAND_add, RAND_status, SSL_ERROR_ZERO_RETURN, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_X509_LOOKUP, SSL_ERROR_SYSCALL, SSL_ERROR_SSL, SSL_ERROR_WANT_CONNECT, SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, SSLError ) +from _ssl import ( CERT_NONE, PROTOCOL_TLS, RAND_add, RAND_status, SSL_ERROR_ZERO_RETURN, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_X509_LOOKUP, SSL_ERROR_SYSCALL, SSL_ERROR_SSL, SSL_ERROR_WANT_CONNECT, SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, SSLError ) time = __import__('time') from eventlib.api import trampoline from eventlib.greenio import ( set_nonblocking, GreenSocket, GreenFile, CONNECT_ERR, CONNECT_SUCCESS, BLOCKING_ERR ) orig_socket = __import__('socket') socket = orig_socket.socket __patched__ = ['SSLSocket', 'wrap_socket', 'sslwrap_simple'] class GreenSSLSocket(__ssl.SSLSocket): """ This is a green version of the SSLSocket class from the ssl module added in 2.6. For documentation on it, please see the Python standard documentation. Python nonblocking ssl objects don't give errors when the other end of the socket is closed (they do notice when the other end is shutdown, though). Any write/read operations will simply hang if the socket is closed from the other end. There is no obvious fix for this problem; it appears to be a limitation of Python's ssl object implementation. A workaround is to set a reasonable timeout on the socket using settimeout(), and to close/reopen the connection when a timeout occurs at an unexpected juncture in the code. """ # we are inheriting from SSLSocket because its constructor calls # do_handshake whose behavior we wish to override def __init__(self, sock, *args, **kw): if not isinstance(sock, GreenSocket): sock = GreenSocket(sock) self.act_non_blocking = sock.act_non_blocking self._timeout = sock.gettimeout() super(GreenSSLSocket, self).__init__(sock.fd, *args, **kw) # the superclass initializer trashes the methods so we remove # the local-object versions of them and let the actual class # methods shine through try: for fn in orig_socket._delegate_methods: delattr(self, fn) except AttributeError: pass def settimeout(self, timeout): self._timeout = timeout def gettimeout(self): return self._timeout def setblocking(self, flag): if flag: self.act_non_blocking = False self._timeout = None else: self.act_non_blocking = True self._timeout = 0.0 def _call_trampolining(self, func, *a, **kw): if self.act_non_blocking: return func(*a, **kw) else: while True: try: return func(*a, **kw) except SSLError as e: if e.args[0] == SSL_ERROR_WANT_READ: trampoline(self, read=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) elif e.args[0] == SSL_ERROR_WANT_WRITE: trampoline(self, write=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) else: raise def write(self, data): """Write DATA to the underlying SSL channel. Returns number of bytes of DATA actually transmitted.""" return self._call_trampolining( super(GreenSSLSocket, self).write, data) def read(self, len=1024): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" return self._call_trampolining( super(GreenSSLSocket, self).read, len) def send (self, data, flags=0): if not isinstance(data, bytes): data = data.encode() if self._sslobj: return self._call_trampolining( super(GreenSSLSocket, self).send, data, flags) else: trampoline(self, write=True, timeout_exc=SSLError('timed out')) return socket.send(self, data, flags) def sendto (self, data, addr, flags=0): # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is if self._sslobj: raise ValueError("sendto not allowed on instances of %s" % self.__class__) else: trampoline(self, write=True, timeout_exc=SSLError('timed out')) return socket.sendto(self, data, addr, flags) def sendall (self, data, flags=0): # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is if self._sslobj: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to sendall() on %s" % self.__class__) amount = len(data) count = 0 while (count < amount): v = self.send(data[count:]) count += v return amount else: trampoline(self, write=True, timeout_exc=SSLError('timed out')) return socket.sendall(self, data, flags) def recv(self, buflen=1024, flags=0): # *NOTE: gross, copied code from ssl.py becase it's not factored well enough to be used as-is if self._sslobj: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to recv() on %s" % self.__class__) read = self.read(buflen) return read else: return socket.recv(self, buflen, flags) def recv_into (self, buffer, nbytes=None, flags=0): if not self.act_non_blocking: trampoline(self, read=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) return super(GreenSSLSocket, self).recv_into(buffer, nbytes, flags) def recvfrom (self, addr, buflen=1024, flags=0): if not self.act_non_blocking: trampoline(self, read=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) return super(GreenSSLSocket, self).recvfrom(addr, buflen, flags) def recvfrom_into (self, buffer, nbytes=None, flags=0): if not self.act_non_blocking: trampoline(self, read=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) return super(GreenSSLSocket, self).recvfrom_into(buffer, nbytes, flags) def unwrap(self): return GreenSocket(super(GreenSSLSocket, self).unwrap()) def do_handshake(self): """Perform a TLS/SSL handshake.""" return self._call_trampolining( super(GreenSSLSocket, self).do_handshake) def _socket_connect(self, addr): real_connect = socket.connect if self.act_non_blocking: return real_connect(self, addr) else: # *NOTE: gross, copied code from greenio because it's not factored # well enough to reuse if self.gettimeout() is None: while True: try: return real_connect(self, addr) except orig_socket.error as e: if e.args[0] in CONNECT_ERR: trampoline(self, write=True) elif e.args[0] in CONNECT_SUCCESS: return else: raise else: end = time.time() + self.gettimeout() while True: try: real_connect(self, addr) except orig_socket.error as e: if e.args[0] in CONNECT_ERR: trampoline(self, write=True, timeout=end-time.time(), timeout_exc=SSLError('timed out')) elif e.args[0] in CONNECT_SUCCESS: return else: raise if time.time() >= end: raise SSLError('timed out') def connect(self, addr): """Connects to remote ADDR, and then wraps the connection in an SSL channel.""" # *NOTE: grrrrr copied this code from ssl.py because of the reference # to socket.connect which we don't want to call directly if self._sslobj: raise ValueError("attempt to connect already-connected SSLSocket!") self._socket_connect(addr) self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, self.cert_reqs, self.ssl_version, self.ca_certs, self.ciphers) if self.do_handshake_on_connect: self.do_handshake() def accept(self): """Accepts a new connection from a remote client, and returns a tuple containing that new connection wrapped with a server-side SSL channel, and the address of the remote client.""" # RDW grr duplication of code from greenio if self.act_non_blocking: newsock, addr = socket.accept(self) else: while True: try: newsock, addr = socket.accept(self) set_nonblocking(newsock) break except orig_socket.error as e: if e.args[0] not in BLOCKING_ERR: raise trampoline(self, read=True, timeout=self.gettimeout(), timeout_exc=SSLError('timed out')) new_ssl = type(self)(newsock, keyfile=self.keyfile, certfile=self.certfile, server_side=True, cert_reqs=self.cert_reqs, ssl_version=self.ssl_version, ca_certs=self.ca_certs, do_handshake_on_connect=self.do_handshake_on_connect, suppress_ragged_eofs=self.suppress_ragged_eofs) return (new_ssl, addr) def dup(self): raise NotImplementedError("Can't dup an ssl object") SSLSocket = GreenSSLSocket def wrap_socket(sock, *a, **kw): return GreenSSLSocket(sock, *a, **kw) -if hasattr(__ssl, 'sslwrap_simple'): - def sslwrap_simple(sock, keyfile=None, certfile=None): - """A replacement for the old socket.ssl function. Designed - for compability with Python 2.5 and earlier. Will disappear in - Python 3.0.""" - ssl_sock = GreenSSLSocket(sock, keyfile=keyfile, certfile=certfile, - server_side=False, - cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, - do_handshake_on_connect=False, - ca_certs=None) - return ssl_sock +def sslwrap_simple(sock, keyfile=None, certfile=None): + """A replacement for the old socket.ssl function. Designed + for compability with Python 2.5 and earlier. Will disappear in + Python 3.0.""" + ssl_sock = GreenSSLSocket(sock, keyfile=keyfile, certfile=certfile, + server_side=False, + cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_TLS, + do_handshake_on_connect=False, + ca_certs=None) + return ssl_sock