diff --git a/msrplib/connect.py b/msrplib/connect.py index d7eba4f..b1a7642 100644 --- a/msrplib/connect.py +++ b/msrplib/connect.py @@ -1,510 +1,512 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details """Establish MSRP connection. This module provides means to obtain a connected and bound MSRPTransport instance. It uniformly handles 3 different configurations you may find your client engaged in: 1. Calling endpoint, not using a relay (DirectConnector) 2. Answering endpoint, not using a relay (DirectAcceptor) 3. Endpoint using a relay (RelayConnection) The answering endpoint may skip using the relay if sure that it's accessible directly. The calling endpoint is unlikely to need the relay. Once you have an instance of the right class, the procedure to establish the connection is the same: full_local_path = connector.prepare() try: ... put full_local_path in SDP 'a:path' attribute ... get full_remote_path from remote's 'a:path: attribute ... (the order of the above steps is reversed if you're the ... answering party, but that does not affect connector's usage) msrptransport = connector.complete(full_remote_path) finally: connector.cleanup() To customize connection's parameters, create a new protocol.URI object and pass it to prepare() function, e.g. local_uri = protocol.URI(use_tls=False, port=5000) connector.prepare(local_uri) prepare() may update local_uri in place with the actual connection parameters used (e.g. if you specified port=0). 'port' attribute of local_uri is currently only respected by AcceptorDirect. Note that, acceptors and connectors are one-use only. MSRPServer, on the contrary, can be used multiple times. """ from __future__ import with_statement import random from twisted.internet.address import IPv4Address from twisted.names.srvconnect import SRVConnector from application.python import Null from application.system import host from eventlib.twistedutil.protocol import GreenClientCreator, SpawnFactory from eventlib import coros from eventlib.api import timeout, sleep from eventlib.green.socket import gethostbyname from gnutls.interfaces.twisted import TLSContext from msrplib import protocol, MSRPError from msrplib.transport import MSRPTransport, MSRPTransactionError, MSRPBadRequest, MSRPNoSuchSessionError from msrplib.digest import process_www_authenticate __all__ = ['MSRPRelaySettings', 'MSRPTimeout', 'MSRPConnectTimeout', 'MSRPRelayConnectTimeout', 'MSRPIncomingConnectTimeout', 'MSRPBindSessionTimeout', 'MSRPRelayAuthError', 'MSRPAuthTimeout', 'MSRPServer', 'DirectConnector', 'DirectAcceptor', 'RelayConnection'] + class MSRPRelaySettings(protocol.ConnectInfo): use_tls = True def __init__(self, domain, username, password, host=None, port=None, use_tls=None, credentials=None): protocol.ConnectInfo.__init__(self, host, use_tls=use_tls, port=port, credentials=credentials) self.domain = domain self.username = username self.password = password def __str__(self): result = "MSRPRelay %s://%s" % (self.scheme, self.host or self.domain) if self.port: result += ':%s' % self.port return result def __repr__(self): params = [self.domain, self.username, self.password, self.host, self.port] if params[-1] is None: del params[-1] if params[-1] is None: del params[-1] return '%s(%s)' % (type(self).__name__, ', '.join(repr(x) for x in params)) @property def uri_domain(self): - return protocol.URI(host=self.domain, port=self.port, use_tls=self.use_tls, session_id='') + return protocol.URI(host=self.domain, port=self.port, use_tls=self.use_tls, session_id='') + class TimeoutMixin(object): @classmethod def timeout(cls, *throw_args): if not throw_args: throw_args = (cls, ) return timeout(cls.seconds, *throw_args) class MSRPTimeout(MSRPError, TimeoutMixin): seconds = 30 class MSRPConnectTimeout(MSRPTimeout): pass class MSRPRelayConnectTimeout(MSRPTimeout): pass class MSRPIncomingConnectTimeout(MSRPTimeout): pass class MSRPBindSessionTimeout(MSRPTimeout): pass class MSRPRelayAuthError(MSRPTransactionError): pass class MSRPAuthTimeout(MSRPTransactionError, TimeoutMixin): code = 408 comment = 'No response to AUTH request' seconds = 30 class MSRPSRVConnector(SRVConnector): def pickServer(self): assert self.servers is not None assert self.orderedServers is not None if not self.servers and not self.orderedServers: # no SRV record, fall back.. return self.domain, 2855 return SRVConnector.pickServer(self) class ConnectBase(object): SRVConnectorClass = MSRPSRVConnector def __init__(self, logger=Null, use_sessmatch=False): self.logger = logger self.use_sessmatch = use_sessmatch self.local_uri = None def _connect(self, local_uri, remote_uri): self.logger.info('Connecting to %s', remote_uri) creator = GreenClientCreator(gtransport_class=MSRPTransport, local_uri=local_uri, logger=self.logger, use_sessmatch=self.use_sessmatch) if remote_uri.host: if remote_uri.use_tls: msrp = creator.connectTLS(remote_uri.host, remote_uri.port or 2855, TLSContext(local_uri.credentials)) else: msrp = creator.connectTCP(remote_uri.host, remote_uri.port or 2855) else: if not remote_uri.domain: raise ValueError("remote_uri must have either 'host' or 'domain'") if remote_uri.use_tls: connectFuncName = 'connectTLS' connectFuncArgs = (TLSContext(local_uri.credentials),) else: connectFuncName = 'connectTCP' connectFuncArgs = () msrp = creator.connectSRV(remote_uri.scheme, remote_uri.domain, connectFuncName=connectFuncName, connectFuncArgs=connectFuncArgs, ConnectorClass=self.SRVConnectorClass) remote_address = msrp.getPeer() self.logger.info('Connected to %s:%s', remote_address.host, remote_address.port) return msrp def _listen(self, local_uri, factory): from twisted.internet import reactor if local_uri.use_tls: port = reactor.listenTLS(local_uri.port or 0, factory, TLSContext(local_uri.credentials), interface=local_uri.host) else: port = reactor.listenTCP(local_uri.port or 0, factory, interface=local_uri.host) local_address = port.getHost() local_uri.port = local_address.port self.logger.info('Listening for incoming %s connections on %s:%s', local_uri.scheme.upper(), local_address.host, local_address.port) return port def cleanup(self, wait=True): pass class DirectConnector(ConnectBase): def __init__(self, logger=Null, use_sessmatch=False): ConnectBase.__init__(self, logger, use_sessmatch) self.host_ip = host.default_ip def __repr__(self): return '<%s at %s local_uri=%s>' % (type(self).__name__, hex(id(self)), getattr(self, 'local_uri', '(none)')) def prepare(self, local_uri=None): local_uri = local_uri or protocol.URI(port=0) local_uri.port = local_uri.port or 2855 self.local_uri = local_uri return [local_uri] def getHost(self): return IPv4Address('TCP', self.host_ip, 0) def complete(self, full_remote_path): with MSRPConnectTimeout.timeout(): msrp = self._connect(self.local_uri, full_remote_path[0]) # can't do the following, because local_uri was already used in the INVITE #msrp.local_uri.port = msrp.getHost().port try: with MSRPBindSessionTimeout.timeout(): msrp.bind(full_remote_path) except: msrp.loseConnection(wait=False) raise return msrp class DirectAcceptor(ConnectBase): def __init__(self, logger=Null, use_sessmatch=False): ConnectBase.__init__(self, logger, use_sessmatch) self.listening_port = None self.transport_event = None def __repr__(self): return '<%s at %s local_uri=%s listening_port=%s>' % (type(self).__name__, hex(id(self)), self.local_uri, self.listening_port) def prepare(self, local_uri=None): """Start listening for an incoming MSRP connection using port and use_tls from local_uri if provided. Return full local path, suitable to put in SDP a:path attribute. Note, that `local_uri' may be updated in place. """ local_uri = local_uri or protocol.URI(port=0) self.transport_event = coros.event() local_uri.host = gethostbyname(local_uri.host) factory = SpawnFactory(self.transport_event, MSRPTransport, local_uri, logger=self.logger, use_sessmatch=self.use_sessmatch) self.listening_port = self._listen(local_uri, factory) self.local_uri = local_uri return [local_uri] def getHost(self): return self.listening_port.getHost() def complete(self, full_remote_path): """Accept an incoming MSRP connection and bind it. Return MSRPTransport instance. """ try: with MSRPIncomingConnectTimeout.timeout(): msrp = self.transport_event.wait() remote_address = msrp.getPeer() self.logger.info('Incoming %s connection from %s:%s', self.local_uri.scheme.upper(), remote_address.host, remote_address.port) finally: self.cleanup() try: with MSRPBindSessionTimeout.timeout(): msrp.accept_binding(full_remote_path) except: msrp.loseConnection(wait=False) raise return msrp def cleanup(self, wait=True): if self.listening_port is not None: self.listening_port.stopListening() self.listening_port = None self.transport_event = None self.local_uri = None def _deliver_chunk(msrp, chunk): msrp.write_chunk(chunk) with MSRPAuthTimeout.timeout(): response = msrp.read_chunk() if response.method is not None: raise MSRPBadRequest if response.transaction_id!=chunk.transaction_id: raise MSRPBadRequest return response class RelayConnection(ConnectBase): def __init__(self, relay, mode, logger=Null, use_sessmatch=False): if mode not in ('active', 'passive'): raise ValueError("RelayConnection mode should be one of 'active' or 'passive'") ConnectBase.__init__(self, logger, use_sessmatch) self.mode = mode self.relay = relay self.msrp = None def __repr__(self): return '<%s at %s relay=%r msrp=%s>' % (type(self).__name__, hex(id(self)), self.relay, self.msrp) def _relay_connect(self): try: msrp = self._connect(self.local_uri, self.relay) except Exception: self.logger.info('Could not connect to relay %s', self.relay) raise local_address = msrp.getHost() remote_address = msrp.getPeer() try: self.local_uri.port = local_address.port msrpdata = protocol.MSRPData(method="AUTH", transaction_id='%x' % random.getrandbits(64)) msrpdata.add_header(protocol.ToPathHeader([self.relay.uri_domain])) msrpdata.add_header(protocol.FromPathHeader([self.local_uri])) response = _deliver_chunk(msrp, msrpdata) if response.code == 401: www_authenticate = response.headers["WWW-Authenticate"] auth, rsp_auth = process_www_authenticate(self.relay.username, self.relay.password, "AUTH", str(self.relay.uri_domain), **www_authenticate.decoded) msrpdata.transaction_id = '%x' % random.getrandbits(64) msrpdata.add_header(protocol.AuthorizationHeader(auth)) response = _deliver_chunk(msrp, msrpdata) if response.code != 200: raise MSRPRelayAuthError(comment=response.comment, code=response.code) msrp.set_local_path(list(response.headers["Use-Path"].decoded)) self.logger.info('Reserved session at relay %s:%s', remote_address.host, remote_address.port) except: self.logger.info('Could not reserve session at relay %s:%s', remote_address.host, remote_address.port) msrp.loseConnection(wait=False) raise return msrp def _relay_connect_timeout(self): with MSRPRelayConnectTimeout.timeout(): return self._relay_connect() def prepare(self, local_uri=None): self.local_uri = local_uri or protocol.URI(port=0) self.msrp = self._relay_connect_timeout() return self.msrp.full_local_path def getHost(self): return self.msrp.getHost() def cleanup(self, wait=True): if self.msrp is not None: self.msrp.loseConnection(wait=wait) self.msrp = None def complete(self, full_remote_path): try: with MSRPBindSessionTimeout.timeout(): if self.mode == 'active': self.msrp.bind(full_remote_path) else: self.msrp.accept_binding(full_remote_path) return self.msrp except: self.msrp.loseConnection(wait=False) raise finally: self.msrp = None class Notifier(coros.event): def wait(self): if self.ready(): self.reset() return coros.event.wait(self) def send(self, value=None, exc=None): if self.ready(): self.reset() return coros.event.send(self, value, exc=exc) class MSRPServer(ConnectBase): """Manage listening sockets. Bind incoming requests. MSRPServer solves the problem with AcceptorDirect: concurrent using of 2 or more AcceptorDirect instances on the same non-zero port is not possible. If you initialize() those instances, one after another, one will listen on the socket and another will get BindError. MSRPServer avoids the problem by sharing the listening socket between multiple connections. It has slightly different interface from AcceptorDirect, so it cannot be considered a drop-in replacement. """ CLOSE_TIMEOUT = MSRPBindSessionTimeout.seconds * 2 def __init__(self, logger): ConnectBase.__init__(self, logger) self.ports = {} # maps interface -> port -> (use_tls, listening Port) self.queue = coros.queue() self.expected_local_uris = {} # maps local_uri -> Logger instance self.expected_remote_paths = {} # maps full_remote_path -> event self.new_full_remote_path_notifier = Notifier() self.factory = SpawnFactory(self._incoming_handler, MSRPTransport, local_uri=None, logger=self.logger) def prepare(self, local_uri=None, logger=None): """Start a listening port specified by local_uri if there isn't one on that port/interface already. Add `local_uri' to the list of expected URIs, so that incoming connections featuring this URI won't be rejected. If `logger' is provided use it for this connection instead of the default one. """ local_uri = local_uri or protocol.URI(port=2855) need_listen = True if local_uri.port: use_tls, listening_port = self.ports.get(local_uri.host, {}).get(local_uri.port, (None, None)) if listening_port is not None: if use_tls==local_uri.use_tls: need_listen = False else: listening_port.stopListening() sleep(0) # make the reactor really stop listening, so that the next listen() call won't fail self.ports.pop(local_uri.host, {}).pop(local_uri.port, None) else: # caller does not care about port number for (use_tls, port) in self.ports[local_uri.host]: if local_uri.use_tls==use_tls: local_uri.port = port.getHost().port need_listen = False if need_listen: port = self._listen(local_uri, self.factory) self.ports.setdefault(local_uri.host, {})[local_uri.port] = (local_uri.use_tls, port) self.expected_local_uris[local_uri] = logger self.local_uri = local_uri return [local_uri] def _incoming_handler(self, msrp): peer = msrp.getPeer() self.logger.info('Incoming connection from %s:%s', peer.host, peer.port) with MSRPBindSessionTimeout.timeout(): chunk = msrp.read_chunk() to_path = chunk.to_path if len(to_path) != 1: msrp.write_response(chunk, 400, 'Invalid To-Path', wait=False) msrp.loseConnection(wait=False) return to_path = to_path[0] if to_path in self.expected_local_uris: logger = self.expected_local_uris.pop(to_path) if logger is not None: msrp.logger = logger msrp.local_uri = to_path else: msrp.write_response(chunk, 481, 'Unknown To-Path', wait=False) msrp.loseConnection(wait=False) return from_path = tuple(chunk.from_path) # at this point, must wait for complete() function to be called which will # provide an event for this full_remote_path while True: event = self.expected_remote_paths.pop(from_path, None) if event is not None: break self.new_full_remote_path_notifier.wait() if event is not None: msrp._set_full_remote_path(list(from_path)) error = msrp.check_incoming_SEND_chunk(chunk) else: error = MSRPNoSuchSessionError if error is None: msrp.write_response(chunk, 200, 'OK') if 'Content-Type' in chunk.headers or chunk.size>0: # chunk must be made available to read_chunk() again because it has payload raise NotImplementedError if event is not None: event.send(msrp) else: msrp.write_response(chunk, error.code, error.comment) def complete(self, full_remote_path): """Wait until one of the incoming connections binds using provided full_remote_path. Return connected and bound MSRPTransport instance. If no such binding was made within MSRPBindSessionTimeout.seconds, raise MSRPBindSessionTimeout. """ full_remote_path = tuple(full_remote_path) event = coros.event() self.expected_remote_paths[full_remote_path] = event try: self.new_full_remote_path_notifier.send() with MSRPBindSessionTimeout.timeout(): return event.wait() finally: self.expected_remote_paths.pop(full_remote_path, None) def cleanup(self, local_uri): """Remove `local_uri' from the list of expected URIs""" self.expected_local_uris.pop(local_uri, None) def stopListening(self): """Close all the sockets that MSRPServer is listening on""" for interface, rest in self.ports.iteritems(): for port, (use_tls, listening_port) in rest: listening_port.stopListening() self.ports = {} def close(self): """Stop listening. Wait for the spawned greenlets to finish""" self.stopListening() with timeout(self.CLOSE_TIMEOUT, None): self.factory.waitall()