diff --git a/msrplib/connect.py b/msrplib/connect.py index e429355..d053896 100644 --- a/msrplib/connect.py +++ b/msrplib/connect.py @@ -1,507 +1,508 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details + """Establish MSRP connection. This module provides means to obtain a connected and bound MSRPTransport instance. It uniformly handles 3 different configurations you may find your client engaged in: 1. Calling endpoint, not using a relay (DirectConnector) 2. Answering endpoint, not using a relay (DirectAcceptor) 3. Endpoint using a relay (RelayConnection) The answering endpoint may skip using the relay if sure that it's accessible directly. The calling endpoint is unlikely to need the relay. Once you have an instance of the right class, the procedure to establish the connection is the same: full_local_path = connector.prepare() try: ... put full_local_path in SDP 'a:path' attribute ... get full_remote_path from remote's 'a:path: attribute ... (the order of the above steps is reversed if you're the ... answering party, but that does not affect connector's usage) msrptransport = connector.complete(full_remote_path) finally: connector.cleanup() To customize connection's parameters, create a new protocol.URI object and pass it to prepare() function, e.g. local_uri = protocol.URI(use_tls=False, port=5000) connector.prepare(local_uri) prepare() may update local_uri in place with the actual connection parameters used (e.g. if you specified port=0). 'port' attribute of local_uri is currently only respected by AcceptorDirect. Note that, acceptors and connectors are one-use only. MSRPServer, on the contrary, can be used multiple times. """ from __future__ import with_statement import random from twisted.internet.address import IPv4Address from twisted.names.srvconnect import SRVConnector from application.python import Null from application.system import host from eventlib.twistedutil.protocol import GreenClientCreator, SpawnFactory from eventlib import coros from eventlib.api import timeout, sleep from eventlib.green.socket import gethostbyname from gnutls.interfaces.twisted import TLSContext from msrplib import protocol, MSRPError from msrplib.transport import MSRPTransport, MSRPTransactionError, MSRPBadRequest, MSRPNoSuchSessionError from msrplib.digest import process_www_authenticate __all__ = ['MSRPRelaySettings', 'MSRPTimeout', 'MSRPConnectTimeout', 'MSRPRelayConnectTimeout', 'MSRPIncomingConnectTimeout', 'MSRPBindSessionTimeout', 'MSRPRelayAuthError', 'MSRPAuthTimeout', 'MSRPServer', 'DirectConnector', 'DirectAcceptor', 'RelayConnection'] class MSRPRelaySettings(protocol.ConnectInfo): use_tls = True def __init__(self, domain, username, password, host=None, port=None, use_tls=None, credentials=None): protocol.ConnectInfo.__init__(self, host, use_tls=use_tls, port=port, credentials=credentials) self.domain = domain self.username = username self.password = password def __str__(self): result = "MSRPRelay %s://%s" % (self.scheme, self.host or self.domain) if self.port: result += ':%s' % self.port return result def __repr__(self): params = [self.domain, self.username, self.password, self.host, self.port] if params[-1] is None: del params[-1] if params[-1] is None: del params[-1] return '%s(%s)' % (type(self).__name__, ', '.join(repr(x) for x in params)) @property def uri_domain(self): return protocol.URI(host=self.domain, port=self.port, use_tls=self.use_tls, session_id='') class TimeoutMixin(object): @classmethod def timeout(cls, *throw_args): if not throw_args: throw_args = (cls, ) return timeout(cls.seconds, *throw_args) class MSRPTimeout(MSRPError, TimeoutMixin): seconds = 30 class MSRPConnectTimeout(MSRPTimeout): pass class MSRPRelayConnectTimeout(MSRPTimeout): pass class MSRPIncomingConnectTimeout(MSRPTimeout): pass class MSRPBindSessionTimeout(MSRPTimeout): pass class MSRPRelayAuthError(MSRPTransactionError): pass class MSRPAuthTimeout(MSRPTransactionError, TimeoutMixin): code = 408 comment = 'No response to AUTH request' seconds = 30 class MSRPSRVConnector(SRVConnector): def pickServer(self): assert self.servers is not None assert self.orderedServers is not None if not self.servers and not self.orderedServers: # no SRV record, fall back.. return self.domain, 2855 return SRVConnector.pickServer(self) class ConnectBase(object): SRVConnectorClass = MSRPSRVConnector def __init__(self, logger=Null, use_sessmatch=False): self.logger = logger self.use_sessmatch = use_sessmatch self.local_uri = None def _connect(self, local_uri, remote_uri): self.logger.info('Connecting to %s' % (remote_uri, )) creator = GreenClientCreator(gtransport_class=MSRPTransport, local_uri=local_uri, logger=self.logger, use_sessmatch=self.use_sessmatch) if remote_uri.host: if remote_uri.use_tls: msrp = creator.connectTLS(remote_uri.host, remote_uri.port or 2855, TLSContext(local_uri.credentials)) else: msrp = creator.connectTCP(remote_uri.host, remote_uri.port or 2855) else: if not remote_uri.domain: raise ValueError("remote_uri must have either 'host' or 'domain'") if remote_uri.use_tls: connectFuncName = 'connectTLS' connectFuncArgs = (TLSContext(local_uri.credentials),) else: connectFuncName = 'connectTCP' connectFuncArgs = () msrp = creator.connectSRV(remote_uri.scheme, remote_uri.domain, connectFuncName=connectFuncName, connectFuncArgs=connectFuncArgs, ConnectorClass=self.SRVConnectorClass) self.logger.info('Connected to %s:%s' % (msrp.getPeer().host, msrp.getPeer().port)) return msrp def _listen(self, local_uri, factory): from twisted.internet import reactor if local_uri.use_tls: port = reactor.listenTLS(local_uri.port or 0, factory, TLSContext(local_uri.credentials), interface=local_uri.host) else: port = reactor.listenTCP(local_uri.port or 0, factory, interface=local_uri.host) local_uri.port = port.getHost().port self.logger.info('Listening for incoming %s connections on %s:%s' % (local_uri.scheme.upper(), port.getHost().host, port.getHost().port)) return port def cleanup(self, wait=True): pass class DirectConnector(ConnectBase): def __init__(self, logger=Null, use_sessmatch=False): ConnectBase.__init__(self, logger, use_sessmatch) self.host_ip = host.default_ip def __repr__(self): return '<%s at %s local_uri=%s>' % (type(self).__name__, hex(id(self)), getattr(self, 'local_uri', '(none)')) def prepare(self, local_uri=None): local_uri = local_uri or protocol.URI(port=0) local_uri.port = local_uri.port or 2855 self.local_uri = local_uri return [local_uri] def getHost(self): return IPv4Address('TCP', self.host_ip, 0) def complete(self, full_remote_path): with MSRPConnectTimeout.timeout(): msrp = self._connect(self.local_uri, full_remote_path[0]) # can't do the following, because local_uri was already used in the INVITE #msrp.local_uri.port = msrp.getHost().port try: with MSRPBindSessionTimeout.timeout(): msrp.bind(full_remote_path) except: msrp.loseConnection(wait=False) raise return msrp class DirectAcceptor(ConnectBase): def __init__(self, logger=Null, use_sessmatch=False): ConnectBase.__init__(self, logger, use_sessmatch) self.listening_port = None self.transport_event = None def __repr__(self): return '<%s at %s local_uri=%s listening_port=%s>' % (type(self).__name__, hex(id(self)), self.local_uri, self.listening_port) def prepare(self, local_uri=None): """Start listening for an incoming MSRP connection using port and use_tls from local_uri if provided. Return full local path, suitable to put in SDP a:path attribute. Note, that `local_uri' may be updated in place. """ local_uri = local_uri or protocol.URI(port=0) self.transport_event = coros.event() local_uri.host = gethostbyname(local_uri.host) factory = SpawnFactory(self.transport_event, MSRPTransport, local_uri, logger=self.logger, use_sessmatch=self.use_sessmatch) self.listening_port = self._listen(local_uri, factory) self.local_uri = local_uri return [local_uri] def getHost(self): return self.listening_port.getHost() def complete(self, full_remote_path): """Accept an incoming MSRP connection and bind it. Return MSRPTransport instance. """ try: with MSRPIncomingConnectTimeout.timeout(): msrp = self.transport_event.wait() msg = 'Incoming %s connection from %s:%s' % (self.local_uri.scheme.upper(), msrp.getPeer().host, msrp.getPeer().port) self.logger.info(msg) finally: self.cleanup() try: with MSRPBindSessionTimeout.timeout(): msrp.accept_binding(full_remote_path) except: msrp.loseConnection(wait=False) raise return msrp def cleanup(self, wait=True): if self.listening_port is not None: self.listening_port.stopListening() self.listening_port = None self.transport_event = None self.local_uri = None def _deliver_chunk(msrp, chunk): msrp.write_chunk(chunk) with MSRPAuthTimeout.timeout(): response = msrp.read_chunk() if response.method is not None: raise MSRPBadRequest if response.transaction_id!=chunk.transaction_id: raise MSRPBadRequest return response class RelayConnection(ConnectBase): def __init__(self, relay, mode, logger=Null, use_sessmatch=False): if mode not in ('active', 'passive'): raise ValueError("RelayConnection mode should be one of 'active' or 'passive'") ConnectBase.__init__(self, logger, use_sessmatch) self.mode = mode self.relay = relay self.msrp = None def __repr__(self): return '<%s at %s relay=%r msrp=%s>' % (type(self).__name__, hex(id(self)), self.relay, self.msrp) def _relay_connect(self): try: msrp = self._connect(self.local_uri, self.relay) except Exception: self.logger.info('Could not connect to relay %s' % self.relay) raise try: self.local_uri.port = msrp.getHost().port msrpdata = protocol.MSRPData(method="AUTH", transaction_id='%x' % random.getrandbits(64)) msrpdata.add_header(protocol.ToPathHeader([self.relay.uri_domain])) msrpdata.add_header(protocol.FromPathHeader([self.local_uri])) response = _deliver_chunk(msrp, msrpdata) if response.code == 401: www_authenticate = response.headers["WWW-Authenticate"] auth, rsp_auth = process_www_authenticate(self.relay.username, self.relay.password, "AUTH", str(self.relay.uri_domain), **www_authenticate.decoded) msrpdata.transaction_id = '%x' % random.getrandbits(64) msrpdata.add_header(protocol.AuthorizationHeader(auth)) response = _deliver_chunk(msrp, msrpdata) if response.code != 200: raise MSRPRelayAuthError(comment=response.comment, code=response.code) msrp.set_local_path(list(response.headers["Use-Path"].decoded)) msg = 'Reserved session at relay %s:%s' % (msrp.getPeer().host, msrp.getPeer().port) self.logger.info(msg) except: msg = 'Could not reserve session at relay %s:%s' % (msrp.getPeer().host, msrp.getPeer().port) self.logger.info(msg) msrp.loseConnection(wait=False) raise return msrp def _relay_connect_timeout(self): with MSRPRelayConnectTimeout.timeout(): return self._relay_connect() def prepare(self, local_uri=None): self.local_uri = local_uri or protocol.URI(port=0) self.msrp = self._relay_connect_timeout() return self.msrp.full_local_path def getHost(self): return self.msrp.getHost() def cleanup(self, wait=True): if self.msrp is not None: self.msrp.loseConnection(wait=wait) self.msrp = None def complete(self, full_remote_path): try: with MSRPBindSessionTimeout.timeout(): if self.mode == 'active': self.msrp.bind(full_remote_path) else: self.msrp.accept_binding(full_remote_path) return self.msrp except: self.msrp.loseConnection(wait=False) raise finally: self.msrp = None class Notifier(coros.event): def wait(self): if self.ready(): self.reset() return coros.event.wait(self) def send(self, value=None, exc=None): if self.ready(): self.reset() return coros.event.send(self, value, exc=exc) class MSRPServer(ConnectBase): """Manage listening sockets. Bind incoming requests. MSRPServer solves the problem with AcceptorDirect: concurrent using of 2 or more AcceptorDirect instances on the same non-zero port is not possible. If you initialize() those instances, one after another, one will listen on the socket and another will get BindError. MSRPServer avoids the problem by sharing the listening socket between multiple connections. It has slightly different interface from AcceptorDirect, so it cannot be considered a drop-in replacement. """ CLOSE_TIMEOUT = MSRPBindSessionTimeout.seconds * 2 def __init__(self, logger): ConnectBase.__init__(self, logger) self.ports = {} # maps interface -> port -> (use_tls, listening Port) self.queue = coros.queue() self.expected_local_uris = {} # maps local_uri -> Logger instance self.expected_remote_paths = {} # maps full_remote_path -> event self.new_full_remote_path_notifier = Notifier() self.factory = SpawnFactory(self._incoming_handler, MSRPTransport, local_uri=None, logger=self.logger) def prepare(self, local_uri=None, logger=None): """Start a listening port specified by local_uri if there isn't one on that port/interface already. Add `local_uri' to the list of expected URIs, so that incoming connections featuring this URI won't be rejected. If `logger' is provided use it for this connection instead of the default one. """ local_uri = local_uri or protocol.URI(port=2855) need_listen = True if local_uri.port: use_tls, listening_port = self.ports.get(local_uri.host, {}).get(local_uri.port, (None, None)) if listening_port is not None: if use_tls==local_uri.use_tls: need_listen = False else: listening_port.stopListening() sleep(0) # make the reactor really stop listening, so that the next listen() call won't fail self.ports.pop(local_uri.host, {}).pop(local_uri.port, None) else: # caller does not care about port number for (use_tls, port) in self.ports[local_uri.host]: if local_uri.use_tls==use_tls: local_uri.port = port.getHost().port need_listen = False if need_listen: port = self._listen(local_uri, self.factory) self.ports.setdefault(local_uri.host, {})[local_uri.port] = (local_uri.use_tls, port) self.expected_local_uris[local_uri] = logger self.local_uri = local_uri return [local_uri] def _incoming_handler(self, msrp): msg = 'Incoming connection from %s:%s' % (msrp.getPeer().host, msrp.getPeer().port) self.logger.info(msg) with MSRPBindSessionTimeout.timeout(): chunk = msrp.read_chunk() - ToPath = tuple(chunk.headers['To-Path'].decoded) - if len(ToPath)!=1: + to_path = chunk.to_path + if len(to_path) != 1: msrp.write_response(chunk, 400, 'Invalid To-Path', wait=False) msrp.loseConnection(wait=False) return - ToPath = ToPath[0] - if ToPath in self.expected_local_uris: - logger = self.expected_local_uris.pop(ToPath) + to_path = to_path[0] + if to_path in self.expected_local_uris: + logger = self.expected_local_uris.pop(to_path) if logger is not None: msrp.logger = logger - msrp.local_uri = ToPath + msrp.local_uri = to_path else: msrp.write_response(chunk, 481, 'Unknown To-Path', wait=False) msrp.loseConnection(wait=False) return - FromPath = tuple(chunk.headers['From-Path'].decoded) + from_path = tuple(chunk.from_path) # at this point, must wait for complete() function to be called which will # provide an event for this full_remote_path while True: - event = self.expected_remote_paths.pop(FromPath, None) + event = self.expected_remote_paths.pop(from_path, None) if event is not None: break self.new_full_remote_path_notifier.wait() if event is not None: - msrp._set_full_remote_path(list(FromPath)) + msrp._set_full_remote_path(list(from_path)) error = msrp.check_incoming_SEND_chunk(chunk) else: error = MSRPNoSuchSessionError if error is None: msrp.write_response(chunk, 200, 'OK') if 'Content-Type' in chunk.headers or chunk.size>0: # chunk must be made available to read_chunk() again because it has payload raise NotImplementedError if event is not None: event.send(msrp) else: msrp.write_response(chunk, error.code, error.comment) def complete(self, full_remote_path): """Wait until one of the incoming connections binds using provided full_remote_path. Return connected and bound MSRPTransport instance. If no such binding was made within MSRPBindSessionTimeout.seconds, raise MSRPBindSessionTimeout. """ full_remote_path = tuple(full_remote_path) event = coros.event() self.expected_remote_paths[full_remote_path] = event try: self.new_full_remote_path_notifier.send() with MSRPBindSessionTimeout.timeout(): return event.wait() finally: self.expected_remote_paths.pop(full_remote_path, None) def cleanup(self, local_uri): """Remove `local_uri' from the list of expected URIs""" self.expected_local_uris.pop(local_uri, None) def stopListening(self): """Close all the sockets that MSRPServer is listening on""" for interface, rest in self.ports.iteritems(): for port, (use_tls, listening_port) in rest: listening_port.stopListening() self.ports = {} def close(self): """Stop listening. Wait for the spawned greenlets to finish""" self.stopListening() with timeout(self.CLOSE_TIMEOUT, None): self.factory.waitall() diff --git a/msrplib/protocol.py b/msrplib/protocol.py index 5129df6..9d2ac03 100644 --- a/msrplib/protocol.py +++ b/msrplib/protocol.py @@ -1,666 +1,743 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details import random import re -from collections import deque +from collections import deque, namedtuple from application.system import host as host_module from twisted.internet.protocol import connectionDone from twisted.protocols.basic import LineReceiver from msrplib import MSRPError class ParsingError(MSRPError): pass -class HeaderParsingError(ParsingError): - def __init__(self, header): +class HeaderParsingError(ParsingError): + def __init__(self, header, reason): self.header = header - ParsingError.__init__(self, "Error parsing %s header" % header) + ParsingError.__init__(self, 'Error parsing {} header: {}'.format(header, reason)) + + +# Header value data types (for decoded values) +# + +ByteRange = namedtuple('ByteRange', ['start', 'end', 'total']) +Status = namedtuple('Status', ['code', 'comment']) +ContentDisposition = namedtuple('ContentDisposition', ['disposition', 'parameters']) + + +# Header value types (describe how to encode/decode the value) +# + +class SimpleHeaderType(object): + data_type = object + + @staticmethod + def decode(encoded): + return encoded + + @staticmethod + def encode(decoded): + return decoded + + +class UTF8HeaderType(object): + data_type = unicode + + @staticmethod + def decode(encoded): + return encoded.decode('utf-8') + + @staticmethod + def encode(decoded): + return decoded.encode('utf-8') + + +class URIHeaderType(object): + data_type = deque + + @staticmethod + def decode(encoded): + return deque(parse_uri(uri) for uri in encoded.split(' ')) + + @staticmethod + def encode(decoded): + return ' '.join(str(uri) for uri in decoded) + + +class IntegerHeaderType(object): + data_type = int + + @staticmethod + def decode(encoded): + return int(encoded) + + @staticmethod + def encode(decoded): + return str(decoded) + + +class LimitedChoiceHeaderType(SimpleHeaderType): + allowed_values = frozenset() + + @classmethod + def decode(cls, encoded): + if encoded not in cls.allowed_values: + raise ValueError('Invalid value: {!r}'.format(encoded)) + return encoded + + +class SuccessReportHeaderType(LimitedChoiceHeaderType): + allowed_values = frozenset({'yes', 'no'}) + + +class FailureReportHeaderType(LimitedChoiceHeaderType): + allowed_values = frozenset({'yes', 'no', 'partial'}) + + +class ByteRangeHeaderType(object): + data_type = ByteRange + + regex = re.compile(r'(?P\d+)-(?P\*|\d+)/(?P\*|\d+)') + + @classmethod + def decode(cls, encoded): + match = cls.regex.match(encoded) + if match is None: + raise ValueError('Invalid byte range value: {!r}'.format(encoded)) + start, end, total = match.groups() + start = int(start) + end = int(end) if end != '*' else None + total = int(total) if total != '*' else None + return ByteRange(start, end, total) + + @staticmethod + def encode(decoded): + start, end, total = decoded + return '{}-{}/{}'.format(start, end or '*', total or '*') + + +class StatusHeaderType(object): + data_type = Status + + @staticmethod + def decode(encoded): + namespace, sep, rest = encoded.partition(' ') + if namespace != '000' or sep != ' ': + raise ValueError('Invalid status value: {!r}'.format(encoded)) + code, _, comment = rest.partition(' ') + if not code.isdigit() or len(code) != 3: + raise ValueError('Invalid status code: {!r}'.format(code)) + return Status(int(code), comment or None) + + @staticmethod + def encode(decoded): + code, comment = decoded + if comment is None: + return '000 {:03d}'.format(code) + else: + return '000 {:03d} {}'.format(code, comment) + + +class ContentDispositionHeaderType(object): + data_type = ContentDisposition + + regex = re.compile(r'(\w+)=("[^"]+"|[^";]+)') + + @classmethod + def decode(cls, encoded): + disposition, _, parameters = encoded.partition(';') + if not disposition: + raise ValueError('Invalid content disposition: {!r}'.format(encoded)) + return ContentDisposition(disposition, {name: value.strip('"') for name, value in cls.regex.findall(parameters)}) + + @staticmethod + def encode(decoded): + disposition, parameters = decoded + return '; '.join([disposition] + ['{}="{}"'.format(name, value) for name, value in parameters.iteritems()]) + + +class ParameterListHeaderType(object): + data_type = dict + + regex = re.compile(r'(\w+)=("[^"]+"|[^",]+)') + + @classmethod + def decode(cls, encoded): + return {name: value.strip('"') for name, value in cls.regex.findall(encoded)} + + @staticmethod + def encode(decoded): + return ', '.join('{}="{}"'.format(name, value) for name, value in decoded.iteritems()) + + +class DigestHeaderType(ParameterListHeaderType): + @classmethod + def decode(cls, encoded): + algorithm, sep, parameters = encoded.partition(' ') + if algorithm != 'Digest' or sep != ' ': + raise ValueError('Invalid Digest header value') + return super(DigestHeaderType, cls).decode(parameters) + + @staticmethod + def encode(decoded): + return 'Digest ' + super(DigestHeaderType, DigestHeaderType).encode(decoded) + + +# Header classes +# class MSRPHeaderMeta(type): - header_classes = {} + __classmap__ = {} + + name = None + + def __init__(cls, name, bases, dictionary): + type.__init__(cls, name, bases, dictionary) + if cls.name is not None: + cls.__classmap__[cls.name] = cls + + def __call__(cls, *args, **kw): + if cls.name is not None: + return super(MSRPHeaderMeta, cls).__call__(*args, **kw) # specialized class, instantiated directly. + else: + return cls._instantiate_specialized_class(*args, **kw) # non-specialized class, instantiated as a more specialized class if available. + + def _instantiate_specialized_class(cls, name, value): + if name in cls.__classmap__: + return super(MSRPHeaderMeta, cls.__classmap__[name]).__call__(value) + else: + return super(MSRPHeaderMeta, cls).__call__(name, value) - def __init__(cls, name, bases, dict): - type.__init__(cls, name, bases, dict) - try: - cls.header_classes[dict['name']] = cls - except KeyError: - pass class MSRPHeader(object): __metaclass__ = MSRPHeaderMeta - def __new__(cls, name, value): - if isinstance(value, str) and name in MSRPHeaderMeta.header_classes: - cls = MSRPHeaderMeta.header_classes[name] - return object.__new__(cls) + name = None + type = SimpleHeaderType def __init__(self, name, value): self.name = name if isinstance(value, str): self.encoded = value else: self.decoded = value - def _raise_error(self): - raise HeaderParsingError(self.name) + def __eq__(self, other): + if isinstance(other, MSRPHeader): + return self.name == other.name and self.decoded == other.decoded + return NotImplemented + + def __ne__(self, other): + return not self == other - def _get_encoded(self): + @property + def encoded(self): if self._encoded is None: - self._encoded = self._encode(self._decoded) + self._encoded = self.type.encode(self._decoded) return self._encoded - def _set_encoded(self, encoded): + @encoded.setter + def encoded(self, encoded): self._encoded = encoded self._decoded = None - encoded = property(_get_encoded, _set_encoded) - - def _get_decoded(self): + @property + def decoded(self): if self._decoded is None: - self._decoded = self._decode(self._encoded) + try: + self._decoded = self.type.decode(self._encoded) + except Exception as e: + raise HeaderParsingError(self.name, str(e)) return self._decoded - def _set_decoded(self, decoded): + @decoded.setter + def decoded(self, decoded): + if not isinstance(decoded, self.type.data_type): + try: + # noinspection PyArgumentList + decoded = self.type.data_type(decoded) + except Exception: + raise TypeError('value must be an instance of {}'.format(self.type.data_type.__name__)) self._decoded = decoded self._encoded = None - decoded = property(_get_decoded, _set_decoded) - - def _decode(self, encoded): - return encoded - - def _encode(self, decoded): - return decoded class MSRPNamedHeader(MSRPHeader): - - def __new__(cls, *args): - if len(args) == 1: - value = args[0] - else: - value = args[1] - return MSRPHeader.__new__(cls, cls.name, value) - - def __init__(self, *args): - if len(args) == 1: - value = args[0] - else: - value = args[1] + def __init__(self, value): MSRPHeader.__init__(self, self.name, value) -class URIHeader(MSRPNamedHeader): - - def _decode(self, encoded): - try: - return deque(parse_uri(uri) for uri in encoded.split(" ")) - except ParsingError: - self._raise_error() - - def _encode(self, decoded): - return " ".join([str(uri) for uri in decoded]) - -class IntegerHeader(MSRPNamedHeader): - - def _decode(self, encoded): - try: - return int(encoded) - except ValueError: - self._raise_error() - - def _encode(self, decoded): - return str(decoded) - -class DigestHeader(MSRPNamedHeader): - def _decode(self, encoded): - try: - algo, params = encoded.split(" ", 1) - except ValueError: - self._raise_error() - if algo != "Digest": - self._raise_error() - try: - param_dict = dict((x.strip('"') for x in param.split("=", 1)) for param in params.split(", ")) - except: - self._raise_error() - return param_dict +class ToPathHeader(MSRPNamedHeader): + name = 'To-Path' + type = URIHeaderType - def _encode(self, decoded): - return "Digest " + ", ".join(['%s="%s"' % tup for tup in decoded.iteritems()]) -class ToPathHeader(URIHeader): - name = "To-Path" +class FromPathHeader(MSRPNamedHeader): + name = 'From-Path' + type = URIHeaderType -class FromPathHeader(URIHeader): - name = "From-Path" class MessageIDHeader(MSRPNamedHeader): - name = "Message-ID" + name = 'Message-ID' + type = SimpleHeaderType + class SuccessReportHeader(MSRPNamedHeader): - name = "Success-Report" + name = 'Success-Report' + type = SuccessReportHeaderType - def _decode(self, encoded): - if encoded not in ["yes", "no"]: - self._raise_error() - return encoded class FailureReportHeader(MSRPNamedHeader): - name = "Failure-Report" + name = 'Failure-Report' + type = FailureReportHeaderType - def _decode(self, encoded): - if encoded not in ["yes", "no", "partial"]: - self._raise_error() - return encoded class ByteRangeHeader(MSRPNamedHeader): - name = "Byte-Range" - - def _decode(self, encoded): - try: - rest, total = encoded.split("/") - fro, to = rest.split("-") - fro = int(fro) - except ValueError: - self._raise_error() - try: - to = int(to) - except ValueError: - if to != "*": - self._raise_error() - to = None - try: - total = int(total) - except ValueError: - if total != "*": - self._raise_error() - total = None - return (fro, to, total) - - def _encode(self, decoded): - fro, to, total = decoded - if to is None: - to = "*" - if total is None: - total = "*" - return "%s-%s/%s" % (fro, to, total) + name = 'Byte-Range' + type = ByteRangeHeaderType @property - def fro(self): - return self.decoded[0] + def start(self): + return self.decoded.start @property - def to(self): - return self.decoded[1] + def end(self): + return self.decoded.end @property def total(self): - return self.decoded[2] - -class StatusHeader(MSRPNamedHeader): - name = "Status" + return self.decoded.total - def _decode(self, encoded): - try: - namespace, rest = encoded.split(" ", 1) - except ValueError: - self._raise_error() - if namespace != "000": - self._raise_error() - rest_sp = rest.split(" ", 1) - try: - if len(rest_sp[0]) != 3: - raise ValueError - code = int(rest_sp[0]) - except ValueError: - self._raise_error() - try: - comment = rest_sp[1] - except IndexError: - comment = None - return (code, comment) - def _encode(self, decoded): - code, comment = decoded - encoded = "000 %03d" % code - if comment is not None: - encoded += " %s" % comment - return encoded +class StatusHeader(MSRPNamedHeader): + name = 'Status' + type = StatusHeaderType @property def code(self): - return self.decoded[0] + return self.decoded.code @property def comment(self): - return self.decoded[1] + return self.decoded.comment -class ExpiresHeader(IntegerHeader): - name = "Expires" -class MinExpiresHeader(IntegerHeader): - name = "Min-Expires" +class ExpiresHeader(MSRPNamedHeader): + name = 'Expires' + type = IntegerHeaderType -class MaxExpiresHeader(IntegerHeader): - name = "Max-Expires" -class UsePathHeader(URIHeader): - name = "Use-Path" +class MinExpiresHeader(MSRPNamedHeader): + name = 'Min-Expires' + type = IntegerHeaderType -class WWWAuthenticateHeader(DigestHeader): - name = "WWW-Authenticate" -class AuthorizationHeader(DigestHeader): - name = "Authorization" +class MaxExpiresHeader(MSRPNamedHeader): + name = 'Max-Expires' + type = IntegerHeaderType -class AuthenticationInfoHeader(MSRPNamedHeader): - name = "Authentication-Info" - def _decode(self, encoded): - try: - param_dict = dict((x.strip('"') for x in param.split("=", 1)) for param in encoded.split(", ")) - except: - self._raise_error() - return param_dict +class UsePathHeader(MSRPNamedHeader): + name = 'Use-Path' + type = URIHeaderType + + +class WWWAuthenticateHeader(MSRPNamedHeader): + name = 'WWW-Authenticate' + type = DigestHeaderType + + +class AuthorizationHeader(MSRPNamedHeader): + name = 'Authorization' + type = DigestHeaderType + + +class AuthenticationInfoHeader(MSRPNamedHeader): + name = 'Authentication-Info' + type = ParameterListHeaderType - def _encode(self, decoded): - return ", ".join(['%s="%s"' % tup for tup in decoded.iteritems()]) class ContentTypeHeader(MSRPNamedHeader): - name = "Content-Type" + name = 'Content-Type' + type = SimpleHeaderType + class ContentIDHeader(MSRPNamedHeader): - name = "Content-ID" + name = 'Content-ID' + type = SimpleHeaderType + class ContentDescriptionHeader(MSRPNamedHeader): - name = "Content-Description" + name = 'Content-Description' + type = SimpleHeaderType + class ContentDispositionHeader(MSRPNamedHeader): - name = "Content-Disposition" + name = 'Content-Disposition' + type = ContentDispositionHeaderType - def _decode(self, encoded): - try: - sp = encoded.split(";") - disposition = sp[0] - parameters = dict(param.split("=", 1) for param in sp[1:]) - except: - self._raise_error() - return [disposition, parameters] - - def _encode(self, decoded): - disposition, parameters = decoded - return ";".join([disposition] + ["%s=%s" % pair for pair in parameters.iteritems()]) class UseNicknameHeader(MSRPNamedHeader): - name = "Use-Nickname" + name = 'Use-Nickname' + type = UTF8HeaderType - def _decode(self, encoded): - return encoded.decode('utf-8') - def _encode(self, decoded): - return decoded.encode('utf-8') +class HeaderOrderMapping(dict): + __levels__ = { + 0: ['To-Path'], + 1: ['From-Path'], + 2: ['Status', 'Message-ID', 'Byte-Range', 'Success-Report', 'Failure-Report'] + + ['Authorization', 'Authentication-Info', 'WWW-Authenticate', 'Expires', 'Min-Expires', 'Max-Expires', 'Use-Path', 'Use-Nickname'], + 3: ['Content-ID', 'Content-Description', 'Content-Disposition'], + 4: ['Content-Type'] + } + + def __init__(self): + super(HeaderOrderMapping, self).__init__({name: level for level, name_list in self.__levels__.items() for name in name_list}) + + def __missing__(self, key): + return 3 if key.startswith('Content-') else 2 + + sort_key = dict.__getitem__ + + +class HeaderOrdering(object): + name_map = HeaderOrderMapping() + sort_key = name_map.sort_key + + +class MissingHeader(object): + decoded = None class MSRPData(object): def __init__(self, transaction_id, method=None, code=None, comment=None, headers=None, data='', contflag='$'): if method is None and code is None: raise ValueError('either method or code must be specified') elif method is not None and code is not None: raise ValueError('method and code cannot be both specified') elif code is None and comment is not None: raise ValueError('comment should only be specified when code is specified') self.transaction_id = transaction_id self.method = method self.code = code self.comment = comment self.headers = headers or {} self.data = data self.contflag = contflag + if method is not None: + self._first_line = 'MSRP {} {}'.format(transaction_id, method) + elif comment is None: + self._first_line = 'MSRP {} {:03d}'.format(transaction_id, code) + else: + self._first_line = 'MSRP {} {:03d} {}'.format(transaction_id, code, comment) def copy(self): - chunk = self.__class__(self.transaction_id) - chunk.__dict__.update(self.__dict__) - chunk.headers = dict(self.headers.items()) - return chunk + return self.__class__(self.transaction_id, self.method, self.code, self.comment, self.headers.copy(), self.data, self.contflag) - def __str__(self): - if self.method is None: - description = "MSRP %s %s" % (self.transaction_id, self.code) - if self.comment is not None: - description += " %s" % self.comment - else: - description = "MSRP %s %s" % (self.transaction_id, self.method) - return description + def __str__(self): # TODO: make __str__ == encode()? + return self._first_line def __repr__(self): - klass = type(self).__name__ - if self.method is None: - description = "%s %s" % (self.transaction_id, self.code) - if self.comment is not None: - description += " %s" % self.comment - else: - description = "%s %s" % (self.transaction_id, self.method) - if self.message_id is not None: - description += ' Message-ID=%s' % self.message_id - for key, value in self.headers.items(): - description += ' %s=%r' % (key, value.encoded) - description += ' len=%s' % self.size - return '<%s at %s %s %s>' % (klass, hex(id(self)), description, self.contflag) + description = self._first_line + for name in sorted(self.headers, key=HeaderOrdering.sort_key): + description += ' {}={!r}'.format(name, self.headers[name].encoded) + description += ' len={}'.format(self.size) + return '<{} at {:#x} {} {}>'.format(self.__class__.__name__, id(self), description, self.contflag) def __eq__(self, other): if not isinstance(other, MSRPData): return False return self.encode()==other.encode() def add_header(self, header): self.headers[header.name] = header def verify_headers(self): - try: # Decode To-/From-path headers first to be able to send responses - self.headers["To-Path"].decoded - self.headers["From-Path"].decoded - except KeyError, e: - raise HeaderParsingError(e.args[0]) + if 'To-Path' not in self.headers: + raise HeaderParsingError('To-Path', 'header is missing') + if 'From-Path' not in self.headers: + raise HeaderParsingError('From-Path', 'header is missing') for header in self.headers.itervalues(): - header.decoded + _ = header.decoded + + @property + def from_path(self): + return self.headers.get('From-Path', MissingHeader).decoded + + @property + def to_path(self): + return self.headers.get('To-Path', MissingHeader).decoded @property def content_type(self): - x = self.headers.get('Content-Type') - if x is None: - return x - return x.decoded + return self.headers.get('Content-Type', MissingHeader).decoded @property def message_id(self): - x = self.headers.get('Message-ID') - if x is None: - return x - return x.decoded + return self.headers.get('Message-ID', MissingHeader).decoded @property def byte_range(self): - x = self.headers.get('Byte-Range') - if x is None: - return x - return x.decoded + return self.headers.get('Byte-Range', MissingHeader).decoded @property def status(self): - return self.headers.get('Status') + return self.headers.get('Status', MissingHeader).decoded @property def failure_report(self): - if "Failure-Report" in self.headers: - return self.headers["Failure-Report"].decoded - else: - return "yes" + return self.headers.get('Failure-Report', MissingHeader).decoded or 'yes' @property def success_report(self): - if "Success-Report" in self.headers: - return self.headers["Success-Report"].decoded - else: - return "no" + return self.headers.get('Success-Report', MissingHeader).decoded or 'no' @property def size(self): return len(self.data) def encode_start(self): - data = [] - if self.method is not None: - data.append("MSRP %(transaction_id)s %(method)s" % self.__dict__) - else: - data.append("MSRP %(transaction_id)s %(code)03d" % self.__dict__ + (self.comment is not None and " %s" % self.comment or "")) - headers = self.headers.copy() - data.append("To-Path: %s" % headers.pop("To-Path").encoded) - data.append("From-Path: %s" % headers.pop("From-Path").encoded) - for hnameval in [(hname, headers.pop(hname).encoded) for hname in headers.keys() if not hname.startswith("Content-")]: - data.append("%s: %s" % hnameval) - for hnameval in [(hname, headers.pop(hname).encoded) for hname in headers.keys() if hname != "Content-Type"]: - data.append("%s: %s" % hnameval) - if len(headers) > 0: - data.append("Content-Type: %s" % headers["Content-Type"].encoded) - data.append("") - data.append("") - return "\r\n".join(data) + lines = [self._first_line] + ['{}: {}'.format(name, self.headers[name].encoded) for name in sorted(self.headers, key=HeaderOrdering.sort_key)] + if 'Content-Type' in self.headers: + lines.append('\r\n') + return '\r\n'.join(lines) def encode_end(self, continuation): - return "\r\n-------%s%s\r\n" % (self.transaction_id, continuation) + return '\r\n-------%s%s\r\n' % (self.transaction_id, continuation) def encode(self): return self.encode_start() + self.data + self.encode_end(self.contflag) class MSRPProtocol(LineReceiver): MAX_LENGTH = 16384 MAX_LINES = 64 def __init__(self, msrp_transport): self.msrp_transport = msrp_transport self.term_buf = '' self.term_re = None self.term_substrings = [] self._reset() def _reset(self): self._chunk_header = '' self.data = None self.line_count = 0 def connectionMade(self): self.msrp_transport._got_transport(self.transport) def lineReceived(self, line): if self.data: if len(line) == 0: terminator = '\r\n-------' + self.data.transaction_id continue_flags = [c+'\r\n' for c in '$#+'] self.term_buf = '' self.term_re = re.compile("^(.*)%s([$#+])\r\n(.*)$" % re.escape(terminator), re.DOTALL) self.term_substrings = [terminator[:i] for i in xrange(1, len(terminator)+1)] + [terminator+cont[:i] for cont in continue_flags for i in xrange(1, len(cont))] self.term_substrings.reverse() self.msrp_transport.logger.received_new_chunk(self._chunk_header+self.delimiter, self.msrp_transport, chunk=self.data) self.msrp_transport._data_start(self.data) self.setRawMode() else: match = self.term_re.match(line) if match: continuation = match.group(1) self.msrp_transport.logger.received_new_chunk(self._chunk_header, self.msrp_transport, chunk=self.data) self.msrp_transport.logger.received_chunk_end(line+self.delimiter, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_start(self.data) self.msrp_transport._data_end(continuation) self._reset() # This line is only here because it allows the subclass MSRPPRotocol_withLogging # to know that the packet ended. In need of redesign. -Luci self.setLineMode('') else: self._chunk_header += line+self.delimiter self.line_count += 1 if self.line_count > self.MAX_LINES: self.msrp_transport.logger.received_illegal_data(self._chunk_header, self.msrp_transport) self._reset() return try: name, value = line.split(': ', 1) except ValueError: return # let this pass silently, we'll just not read this line else: self.data.add_header(MSRPHeader(name, value)) else: # we received a new message try: msrp, transaction_id, rest = line.split(" ", 2) except ValueError: self.msrp_transport.logger.received_illegal_data(line+self.delimiter, self.msrp_transport) return # drop connection? if msrp != "MSRP": self.msrp_transport.logger.received_illegal_data(line+self.delimiter, self.msrp_transport) return # drop connection? method, code, comment = None, None, None rest_sp = rest.split(" ", 1) try: if len(rest_sp[0]) != 3: raise ValueError code = int(rest_sp[0]) except ValueError: # we have a request method = rest_sp[0] else: # we have a response if len(rest_sp) > 1: comment = rest_sp[1] self.data = MSRPData(transaction_id, method, code, comment) self.term_re = re.compile("^-------%s([$#+])$" % re.escape(transaction_id)) self._chunk_header = line+self.delimiter def lineLengthExceeded(self, line): self._reset() def rawDataReceived(self, data): data = self.term_buf + data match = self.term_re.match(data) if match: # we got the last data for this message contents, continuation, extra = match.groups() if contents: self.msrp_transport.logger.received_chunk_data(contents, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_write(contents, final=True) self.msrp_transport.logger.received_chunk_end('\r\n-------%s%s\r\n' % (self.data.transaction_id, continuation), self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_end(continuation) self._reset() self.setLineMode(extra) else: for term in self.term_substrings: if data.endswith(term): self.term_buf = data[-len(term):] data = data[:-len(term)] break else: self.term_buf = '' self.msrp_transport.logger.received_chunk_data(data, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_write(data, final=False) def connectionLost(self, reason=connectionDone): self.msrp_transport._connectionLost(reason) _re_uri = re.compile("^(?P.*?)://(((?P.*?)@)?(?P.*?)(:(?P[0-9]+?))?)(/(?P.*?))?;(?P.*?)(;(?P.*))?$") def parse_uri(uri_str): match = _re_uri.match(uri_str) if match is None: raise ParsingError("Cannot parse URI") uri_params = match.groupdict() if uri_params["port"] is not None: uri_params["port"] = int(uri_params["port"]) if uri_params["parameters"] is not None: try: uri_params["parameters"] = dict(param.split("=") for param in uri_params["parameters"].split(";")) except ValueError: raise ParsingError("Cannot parse URI parameters") scheme = uri_params.pop("scheme") if scheme == "msrp": uri_params["use_tls"] = False elif scheme == "msrps": uri_params["use_tls"] = True else: raise ParsingError("Invalid scheme user in URI: %s" % scheme) if uri_params["transport"] != "tcp": raise ParsingError('Invalid transport in URI, only "tcp" is accepted: %s' % uri_params["transport"]) return URI(**uri_params) class ConnectInfo(object): host = None use_tls = True port = 2855 def __init__(self, host=None, use_tls=None, port=None, credentials=None): if host is not None: self.host = host if use_tls is not None: self.use_tls = use_tls if port is not None: self.port = port self.credentials = credentials if self.use_tls and self.credentials is None: from gnutls.interfaces.twisted import X509Credentials self.credentials = X509Credentials(None, None) @property def scheme(self): if self.use_tls: return 'msrps' else: return 'msrp' # use TLS_URI and TCP_URI ? class URI(ConnectInfo): def __init__(self, host=None, use_tls=None, user=None, port=None, session_id=None, transport="tcp", parameters=None, credentials=None): ConnectInfo.__init__(self, host or host_module.default_ip, use_tls=use_tls, port=port, credentials=credentials) self.user = user if session_id is None: session_id = '%x' % random.getrandbits(80) self.session_id = session_id self.transport = transport if parameters is None: self.parameters = {} else: self.parameters = parameters def __repr__(self): params = [self.host, self.use_tls, self.user, self.port, self.session_id, self.transport, self.parameters] defaults = [False, None, None, None, 'tcp', {}] while defaults and params[-1]==defaults[-1]: del params[-1] del defaults[-1] return '%s(%s)' % (self.__class__.__name__, ', '.join(`x` for x in params)) def __str__(self): uri_str = [] if self.use_tls: uri_str.append("msrps://") else: uri_str.append("msrp://") if self.user: uri_str.extend([self.user, "@"]) uri_str.append(self.host) if self.port: uri_str.extend([":", str(self.port)]) if self.session_id: uri_str.extend(["/", self.session_id]) uri_str.extend([";", self.transport]) for key, value in self.parameters.iteritems(): uri_str.extend([";", key, "=", value]) return "".join(uri_str) def __eq__(self, other): """MSRP URI comparison according to section 6.1 of RFC 4975""" if self is other: return True try: if self.use_tls != other.use_tls: return False if self.host.lower() != other.host.lower(): return False if self.port != other.port: return False if self.session_id != other.session_id: return False if self.transport.lower() != other.transport.lower(): return False except AttributeError: return False return True def __ne__(self, other): return not self == other def __hash__(self): return hash((self.use_tls, self.host.lower(), self.port, self.session_id, self.transport.lower())) diff --git a/msrplib/session.py b/msrplib/session.py index 8808371..ba4c8a7 100644 --- a/msrplib/session.py +++ b/msrplib/session.py @@ -1,348 +1,350 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details # -import random import traceback from time import time from twisted.internet.error import ConnectionClosed, ConnectionDone from twisted.python.failure import Failure from gnutls.errors import GNUTLSError from eventlib import api, coros, proc from eventlib.twistedutil.protocol import ValueQueue from msrplib import protocol, MSRPError from msrplib.transport import make_report, make_response, MSRPTransactionError -from msrplib.protocol import ContentTypeHeader, MSRPHeader + ConnectionClosedErrors = (ConnectionClosed, GNUTLSError) + class MSRPSessionError(MSRPError): pass + class MSRPBadContentType(MSRPTransactionError): code = 415 comment = 'Unsupported media type' -class LocalResponse(MSRPTransactionError): +class LocalResponse(MSRPTransactionError): def __repr__(self): return '' % (self.code, self.comment) + Response200OK = LocalResponse("OK", 200) Response408Timeout = LocalResponse("Timed out while waiting for transaction response", 408) def contains_mime_type(mimetypelist, mimetype): """Return True if mimetypelist contains mimetype. mimietypelist either contains the complete mime types, such as 'text/plain', or simple patterns, like 'text/*', or simply '*'. """ mimetype = mimetype.lower().partition(';')[0] for pattern in mimetypelist: pattern = pattern.lower() if pattern == '*': return True if pattern == mimetype: return True if pattern.endswith('/*') and mimetype.startswith(pattern[:-1]): return True return False class OutgoingChunk(object): __slots__ = ('chunk', 'response_callback') def __init__(self, chunk, response_callback=None): self.chunk = chunk self.response_callback = response_callback class MSRPSession(object): RESPONSE_TIMEOUT = 30 SHUTDOWN_TIMEOUT = 1 KEEPALIVE_INTERVAL = 60 def __init__(self, msrptransport, accept_types=['*'], on_incoming_cb=None, automatic_reports=True): self.msrp = msrptransport self.accept_types = accept_types self.automatic_reports = automatic_reports if not callable(on_incoming_cb): raise TypeError('on_incoming_cb must be callable: %r' % on_incoming_cb) self._on_incoming_cb = on_incoming_cb self.expected_responses = {} self.outgoing = coros.queue() self.reader_job = proc.spawn(self._reader) self.writer_job = proc.spawn(self._writer) self.state = 'CONNECTED' # -> 'FLUSHING' -> 'CLOSING' -> 'DONE' # in FLUSHING writer sends only while there's something in the outgoing queue # then it exits and sets state to 'CLOSING' which makes reader only pay attention # to responses and success reports. (XXX it could now discard incoming data chunks # with direct write() since writer is dead) self.reader_job.link(self.writer_job) self.last_expected_response = 0 self.keepalive_proc = proc.spawn(self._keepalive) def _get_logger(self): return self.msrp.logger def _set_logger(self, logger): self.msrp.logger = logger logger = property(_get_logger, _set_logger) def set_state(self, state): self.logger.debug('%s (was %s)' % (state, self.state)) self.state = state @property def connected(self): return self.state=='CONNECTED' def shutdown(self, wait=True): """Send the messages already in queue then close the connection""" self.set_state('FLUSHING') self.keepalive_proc.kill() self.keepalive_proc = None self.outgoing.send(None) if wait: self.writer_job.wait(None, None) self.reader_job.wait(None, None) def _keepalive(self): while True: api.sleep(self.KEEPALIVE_INTERVAL) if not self.connected: return try: chunk = self.msrp.make_send_request() - chunk.add_header(MSRPHeader('Keep-Alive', 'yes')) + chunk.add_header(protocol.MSRPHeader('Keep-Alive', 'yes')) self.deliver_chunk(chunk) except MSRPTransactionError, e: if e.code == 408: self.msrp.loseConnection(wait=False) self.set_state('CLOSING') return def _handle_incoming_response(self, chunk): try: response_cb, timer = self.expected_responses.pop(chunk.transaction_id) except KeyError: pass else: if timer is not None: timer.cancel() response_cb(chunk) def _check_incoming_SEND(self, chunk): error = self.msrp.check_incoming_SEND_chunk(chunk) if error is not None: return error if chunk.data: - if chunk.headers.get('Content-Type') is None: - return MSRPBadContentType('Content-type header missing') - if not contains_mime_type(self.accept_types, chunk.headers['Content-Type'].decoded): + if chunk.content_type is None: + return MSRPBadContentType('Content-Type header missing') + if not contains_mime_type(self.accept_types, chunk.content_type): return MSRPBadContentType def _handle_incoming_SEND(self, chunk): error = self._check_incoming_SEND(chunk) if error is None: code, comment = 200, 'OK' else: code, comment = error.code, error.comment response = make_response(chunk, code, comment) if response is not None: self.outgoing.send(OutgoingChunk(response)) if code == 200: self._on_incoming_cb(chunk) if self.automatic_reports: report = make_report(chunk, 200, 'OK') if report is not None: self.outgoing.send(OutgoingChunk(report)) def _handle_incoming_REPORT(self, chunk): self._on_incoming_cb(chunk) def _handle_incoming_NICKNAME(self, chunk): if 'Use-Nickname' not in chunk.headers or 'Success-Report' in chunk.headers or 'Failure-Report' in chunk.headers: response = make_response(chunk, 400, 'Bad request') self.outgoing.send(OutgoingChunk(response)) return self._on_incoming_cb(chunk) def _reader(self): """Wait forever for new chunks. Notify the user about the good ones through self._on_incoming_cb. If a response to a previously sent chunk is received, pop the corresponding response_cb from self.expected_responses and send the response there. """ error = Failure(ConnectionDone()) try: self.writer_job.link(self.reader_job) try: while self.state in ['CONNECTED', 'FLUSHING']: chunk = self.msrp.read_chunk() if chunk.method is None: # response self._handle_incoming_response(chunk) else: method = getattr(self, '_handle_incoming_%s' % chunk.method, None) if method is not None: method(chunk) else: response = make_response(chunk, 501, 'Method unknown') self.outgoing.send(OutgoingChunk(response)) except proc.LinkedExited: # writer has exited pass finally: self.writer_job.unlink(self.reader_job) self.writer_job.kill() self.logger.debug('reader: expecting responses only') delay = time() - self.last_expected_response if delay>=0 and self.expected_responses: # continue read the responses until the last timeout expires with api.timeout(delay, None): while self.expected_responses: chunk = self.msrp.read_chunk() if chunk.method is None: self._handle_incoming_response(chunk) else: self.logger.debug('dropping incoming %r' % chunk) # read whatever left in the queue with api.timeout(0, None): while self.msrp._queue: chunk = self.msrp.read_chunk() if chunk.method is None: self._handle_incoming_response(chunk) else: self.logger.debug('dropping incoming %r' % chunk) self.logger.debug('reader: done') except ConnectionClosedErrors, ex: self.logger.debug('reader: exiting because of %r' % ex) error=Failure(ex) except Exception: self.logger.err('reader: captured unhandled exception\n%r' % traceback.format_exc()) error=Failure() raise finally: self._on_incoming_cb(error=error) self.msrp.loseConnection(wait=False) self.set_state('DONE') def _writer(self): try: while self.state=='CONNECTED' or (self.state=='FLUSHING' and self.outgoing): item = self.outgoing.wait() if item is None: break self._write_chunk(item.chunk, item.response_callback) except ConnectionClosedErrors + (proc.LinkedExited, proc.ProcExit), e: self.logger.debug('writer: exiting because of %r' % e) except: self.logger.err('writer: captured unhandled exception:\n%s' % traceback.format_exc()) raise finally: self.msrp.loseConnection(wait=False) self.set_state('CLOSING') def _write_chunk(self, chunk, response_cb=None): assert chunk.transaction_id not in self.expected_responses, "MSRP transaction %r is already in progress" % chunk.transaction_id self.msrp.write_chunk(chunk) if response_cb is not None: timer = api.get_hub().schedule_call_global(self.RESPONSE_TIMEOUT, self._response_timeout, chunk.transaction_id, Response408Timeout) self.expected_responses[chunk.transaction_id] = (response_cb, timer) self.last_expected_response = time() + self.RESPONSE_TIMEOUT def _response_timeout(self, id, timeout_error): response_cb, timer = self.expected_responses.pop(id, (None, None)) if response_cb is not None: response_cb(timeout_error) if timer is not None: timer.cancel() def send_chunk(self, chunk, response_cb=None): """Send `chunk'. Report the result via `response_cb'. When `response_cb' argument is present, it will be used to report the transaction response to the caller. When a response is received or generated locally, `response_cb' is called with one argument. The function must do something quickly and must not block, because otherwise it would the reader greenlet. If no response was received after RESPONSE_TIMEOUT seconds, * 408 response is generated if Failure-Report was 'yes' or absent * 200 response is generated if Failure-Report was 'partial' or 'no' Note that it's rather wasteful to provide `response_cb' argument other than None for chunks with Failure-Report='no' since it will always fire 30 seconds later with 200 result (unless the other party is broken and ignores Failure-Report header) If sending is impossible raise MSRPSessionError. """ assert chunk.transaction_id not in self.expected_responses, "MSRP transaction %r is already in progress" % chunk.transaction_id if response_cb is not None and not callable(response_cb): raise TypeError('response_cb must be callable: %r' % (response_cb, )) if self.state != 'CONNECTED': raise MSRPSessionError('Cannot send chunk because MSRPSession is %s' % self.state) if self.msrp._disconnected_event.ready(): raise MSRPSessionError(str(self.msrp._disconnected_event.wait())) self.outgoing.send(OutgoingChunk(chunk, response_cb)) def send_report(self, chunk, code, reason): if chunk.method != 'SEND': raise ValueError('reports may only be sent for SEND chunks') report = make_report(chunk, code, reason) if report is not None: self.send_chunk(report) def deliver_chunk(self, chunk, event=None): """Send chunk, wait for the transaction response (if Failure-Report header is not 'no'). Return the transaction response if it's a success, raise MSRPTransactionError if it's not. If chunk's Failure-Report is 'no', return None immediately. """ if chunk.failure_report!='no' and event is None: event = coros.event() self.send_chunk(chunk, event.send) if event is not None: response = event.wait() if isinstance(response, Exception): raise response elif 200 <= response.code <= 299: return response raise MSRPTransactionError(comment=response.comment, code=response.code) def make_message(self, msg, content_type, message_id=None): chunk = self.msrp.make_send_request(data=msg, message_id=message_id) chunk.add_header(protocol.ContentTypeHeader(content_type)) return chunk def send_message(self, msg, content_type): chunk = self.make_message(msg, content_type) self.send_chunk(chunk) return chunk def deliver_message(self, msg, content_type): chunk = self.make_message(msg, content_type) self.deliver_chunk(chunk) return chunk class GreenMSRPSession(MSRPSession): def __init__(self, msrptransport, accept_types=['*']): MSRPSession.__init__(self, msrptransport, accept_types, on_incoming_cb=self._incoming_cb) self.incoming = ValueQueue() def receive_chunk(self): return self.incoming.wait() def _incoming_cb(self, value=None, error=None): if error is not None: self.incoming.send_exception(error.type, error.value, error.tb) else: self.incoming.send(value) # TODO: # 413 - requires special action both in reader and in writer # continuation: # # All MSRP endpoints MUST be able to receive the multipart/mixed [15] and multipart/alternative [15] media-types. diff --git a/msrplib/transport.py b/msrplib/transport.py index 6e498a8..7c590c4 100644 --- a/msrplib/transport.py +++ b/msrplib/transport.py @@ -1,346 +1,341 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details import random from application import log from twisted.internet.error import ConnectionDone from eventlib.twistedutil.protocol import GreenTransportBase from msrplib import protocol, MSRPError from msrplib.trafficlog import Logger log = log.get_logger('msrplib') class ChunkParseError(MSRPError): """Failed to parse incoming chunk""" class MSRPTransactionError(MSRPError): def __init__(self, comment=None, code=None): if comment is not None: self.comment = comment if code is not None: self.code = code if not hasattr(self, 'code'): raise TypeError("must provide 'code'") def __str__(self): if hasattr(self, 'comment'): return '%s %s' % (self.code, self.comment) else: return str(self.code) class MSRPBadRequest(MSRPTransactionError): code = 400 comment = 'Bad Request' def __str__(self): return 'Remote party sent bogus data' class MSRPNoSuchSessionError(MSRPTransactionError): code = 481 comment = 'No such session' data_start, data_end, data_write, data_final_write = xrange(4) class MSRPProtocol_withLogging(protocol.MSRPProtocol): _new_chunk = True def rawDataReceived(self, data): self.msrp_transport.logger.report_in(data, self.msrp_transport, self._new_chunk) protocol.MSRPProtocol.rawDataReceived(self, data) def lineReceived(self, line): self.msrp_transport.logger.report_in(line+self.delimiter, self.msrp_transport, self._new_chunk) self._new_chunk = False protocol.MSRPProtocol.lineReceived(self, line) def connectionLost(self, reason): msg = 'Closed connection to %s:%s' % (self.transport.getPeer().host, self.transport.getPeer().port) if not isinstance(reason.value, ConnectionDone): msg += ' (%s)' % reason.getErrorMessage() self.msrp_transport.logger.info(msg) protocol.MSRPProtocol.connectionLost(self, reason) def setLineMode(self, extra): self._new_chunk = True self.msrp_transport.logger.report_in('', self.msrp_transport, packet_done=True) return protocol.MSRPProtocol.setLineMode(self, extra) def make_report(chunk, code, comment): if chunk.success_report == 'yes' or (chunk.failure_report in ('yes', 'partial') and code != 200): report = protocol.MSRPData(transaction_id='%x' % random.getrandbits(64), method='REPORT') - report.add_header(protocol.ToPathHeader(chunk.headers['From-Path'].decoded)) - report.add_header(protocol.FromPathHeader([chunk.headers['To-Path'].decoded[0]])) - report.add_header(protocol.StatusHeader('000 %d %s' % (code, comment))) + report.add_header(protocol.ToPathHeader(chunk.from_path)) + report.add_header(protocol.FromPathHeader([chunk.to_path[0]])) + report.add_header(protocol.StatusHeader(protocol.Status(code, comment))) report.add_header(protocol.MessageIDHeader(chunk.message_id)) - byterange = chunk.headers.get('Byte-Range') - if byterange is None: + if chunk.byte_range is None: start = 1 total = chunk.size else: - start, end, total = byterange.decoded - report.add_header(protocol.ByteRangeHeader((start, start+chunk.size-1, total))) + start, end, total = chunk.byte_range + report.add_header(protocol.ByteRangeHeader(protocol.ByteRange(start, start+chunk.size-1, total))) return report else: return None def make_response(chunk, code, comment): """Construct a response to a request as described in RFC4975 Section 7.2. If the response is not needed, return None. If a required header missing, raise ChunkParseError. """ - if chunk.failure_report=='no': + if chunk.failure_report == 'no': return - if chunk.failure_report=='partial' and code==200: + if chunk.failure_report == 'partial' and code == 200: return - if not chunk.headers.get('To-Path'): + if chunk.to_path is None: raise ChunkParseError('missing To-Path header: %r' % chunk) - if not chunk.headers.get('From-Path'): + if chunk.from_path is None: raise ChunkParseError('missing From-Path header: %r' % chunk) - if chunk.method=='SEND': - to_path = [chunk.headers['From-Path'].decoded[0]] + if chunk.method == 'SEND': + to_path = [chunk.from_path[0]] else: - to_path = chunk.headers['From-Path'].decoded - from_path = [chunk.headers['To-Path'].decoded[0]] + to_path = chunk.from_path + from_path = [chunk.to_path[0]] response = protocol.MSRPData(chunk.transaction_id, code=code, comment=comment) response.add_header(protocol.ToPathHeader(to_path)) response.add_header(protocol.FromPathHeader(from_path)) return response class MSRPTransport(GreenTransportBase): protocol_class = MSRPProtocol_withLogging def __init__(self, local_uri, logger, use_sessmatch=False): GreenTransportBase.__init__(self) if local_uri is not None and not isinstance(local_uri, protocol.URI): raise TypeError('Not MSRP URI instance: %r' % (local_uri, )) # The following members define To-Path and From-Path headers as following: # * Outgoing request: # From-Path: local_uri # To-Path: local_path + remote_path + [remote_uri] # * Incoming request: # From-Path: remote_path + remote_uri # To-Path: remote_path + local_path + [local_uri] # XXX self.local_uri = local_uri if logger is None: logger = Logger() self.logger = logger self.local_path = [] self.remote_uri = None self.remote_path = [] self.use_sessmatch = use_sessmatch self._msrpdata = None def next_host(self): if self.local_path: return self.local_path[0] return self.full_remote_path[0] def set_local_path(self, lst): self.local_path = lst @property def full_local_path(self): "suitable to put into INVITE" return self.local_path + [self.local_uri] @property def full_remote_path(self): return self.remote_path + [self.remote_uri] def make_request(self, method): transaction_id = '%x' % random.getrandbits(64) chunk = protocol.MSRPData(transaction_id=transaction_id, method=method) chunk.add_header(protocol.ToPathHeader(self.local_path + self.remote_path + [self.remote_uri])) chunk.add_header(protocol.FromPathHeader([self.local_uri])) return chunk def make_send_request(self, message_id=None, data='', start=1, end=None, length=None): chunk = self.make_request('SEND') if end is None: end = start - 1 + len(data) if length is None: length = start - 1 + len(data) if end == length != '*': contflag = '$' else: contflag = '+' - chunk.add_header(protocol.ByteRangeHeader((start, end if length <= 2048 else None, length))) + chunk.add_header(protocol.ByteRangeHeader(protocol.ByteRange(start, end if length <= 2048 else None, length))) if message_id is None: message_id = '%x' % random.getrandbits(64) chunk.add_header(protocol.MessageIDHeader(message_id)) chunk.data = data chunk.contflag = contflag return chunk def _data_start(self, data): self._queue.send((data_start, data)) def _data_end(self, continuation): self._queue.send((data_end, continuation)) def _data_write(self, contents, final): if final: self._queue.send((data_final_write, contents)) else: self._queue.send((data_write, contents)) def write(self, bytes, wait=True): """Write `bytes' to the socket. If `wait' is true, wait for an operation to complete""" self.logger.report_out(bytes, self.transport) return GreenTransportBase.write(self, bytes, wait) def write_chunk(self, chunk, wait=True): trailer = chunk.encode_start() footer = chunk.encode_end(chunk.contflag) self.write(trailer+chunk.data+footer, wait=wait) self.logger.sent_new_chunk(trailer, self, chunk=chunk) if chunk.data: self.logger.sent_chunk_data(chunk.data, self, chunk.transaction_id) self.logger.sent_chunk_end(footer, self, chunk.transaction_id) def read_chunk(self, max_size=1024*1024*4): """Wait for a new chunk and return it. If there was an error, close the connection and raise ChunkParseError. In case of unintelligible input, lose the connection and return None. When the connection is closed, raise the reason of the closure (e.g. ConnectionDone). """ assert max_size > 0 if self._msrpdata is None: func, msrpdata = self._wait() if func!=data_start: self.logger.debug('Bad data: %r %r' % (func, msrpdata)) self.loseConnection() raise ChunkParseError else: msrpdata = self._msrpdata data = msrpdata.data func, param = self._wait() while func == data_write: data += param if len(data) > max_size: self.logger.debug('Chunk is too big (max_size=%d bytes)' % max_size) self.loseConnection() raise ChunkParseError func, param = self._wait() if func == data_final_write: data += param func, param = self._wait() if func != data_end: self.logger.debug('Bad data: %r %s' % (func, repr(param)[:100])) self.loseConnection() raise ChunkParseError if param not in "$+#": self.logger.debug('Bad data: %r %s' % (func, repr(param)[:100])) self.loseConnection() raise ChunkParseError msrpdata.data = data msrpdata.contflag = param self._msrpdata = None self.logger.debug('read_chunk -> %r' % (msrpdata, )) return msrpdata def _set_full_remote_path(self, full_remote_path): "as received in response to INVITE" if not all(isinstance(x, protocol.URI) for x in full_remote_path): raise TypeError('Not all elements are MSRP URI: %r' % full_remote_path) self.remote_uri = full_remote_path[-1] self.remote_path = full_remote_path[:-1] def bind(self, full_remote_path): self._set_full_remote_path(full_remote_path) chunk = self.make_send_request() self.write_chunk(chunk) # With some ACM implementations both parties may think they are active, # so they will both send an empty SEND request. -Saul while True: chunk = self.read_chunk() if chunk.code is None: # This was not a response, it was a request if chunk.method == 'SEND' and not chunk.data: self.write_response(chunk, 200, 'OK') else: self.loseConnection(wait=False) raise MSRPNoSuchSessionError('Chunk received while binding session: %s' % chunk) elif chunk.code != 200: self.loseConnection(wait=False) raise MSRPNoSuchSessionError('Cannot bind session: %s' % chunk) else: break def write_response(self, chunk, code, comment, wait=True): """Generate and write the response, lose the connection in case of error""" try: response = make_response(chunk, code, comment) except ChunkParseError, ex: log.error('Failed to generate a response: %s' % ex) self.loseConnection(wait=False) raise except Exception: log.exception('Failed to generate a response') self.loseConnection(wait=False) raise else: if response is not None: self.write_chunk(response, wait=wait) def accept_binding(self, full_remote_path): self._set_full_remote_path(full_remote_path) chunk = self.read_chunk() error = self.check_incoming_SEND_chunk(chunk) if error is None: code, comment = 200, 'OK' else: code, comment = error.code, error.comment self.write_response(chunk, code, comment) if 'Content-Type' in chunk.headers or chunk.size>0: # deliver chunk to read_chunk data = chunk.data chunk.data = '' self._data_start(chunk) self._data_write(data, final=True) self._data_end(chunk.contflag) def check_incoming_SEND_chunk(self, chunk): """Check the 'To-Path' and 'From-Path' of the incoming SEND chunk. Return None is the paths are valid for this connection. If an error is detected and MSRPError is created and returned. """ - assert chunk.method=='SEND', repr(chunk) - try: - ToPath = chunk.headers['To-Path'] - except KeyError: + assert chunk.method == 'SEND', repr(chunk) + if chunk.to_path is None: return MSRPBadRequest('To-Path header missing') - try: - FromPath = chunk.headers['From-Path'] - except KeyError: + if chunk.from_path is None: return MSRPBadRequest('From-Path header missing') - ToPath = list(ToPath.decoded) - FromPath = list(FromPath.decoded) - ExpectedTo = [self.local_uri] - ExpectedFrom = self.local_path + self.remote_path + [self.remote_uri] + to_path = list(chunk.to_path) + from_path = list(chunk.from_path) + expected_to = [self.local_uri] + expected_from = self.local_path + self.remote_path + [self.remote_uri] # Match only session ID when use_sessmatch is set (http://tools.ietf.org/html/draft-ietf-simple-msrp-sessmatch-10) if self.use_sessmatch: - if ToPath[0].session_id != ExpectedTo[0].session_id: - log.error('To-Path: expected session_id %s, got %s' % (ExpectedTo[0].session_id, ToPath[0].session_id)) + if to_path[0].session_id != expected_to[0].session_id: + log.error('To-Path: expected session_id %s, got %s' % (expected_to[0].session_id, to_path[0].session_id)) return MSRPNoSuchSessionError('Invalid To-Path') - if FromPath[0].session_id != ExpectedFrom[0].session_id: - log.error('From-Path: expected session_id %s, got %s' % (ExpectedFrom[0].session_id, FromPath[0].session_id)) + if from_path[0].session_id != expected_from[0].session_id: + log.error('From-Path: expected session_id %s, got %s' % (expected_from[0].session_id, from_path[0].session_id)) return MSRPNoSuchSessionError('Invalid From-Path') else: - if ToPath != ExpectedTo: - log.error('To-Path: expected %r, got %r' % (ExpectedTo, ToPath)) + if to_path != expected_to: + log.error('To-Path: expected %r, got %r' % (expected_to, to_path)) return MSRPNoSuchSessionError('Invalid To-Path') - if FromPath != ExpectedFrom: - log.error('From-Path: expected %r, got %r' % (ExpectedFrom, FromPath)) + if from_path != expected_from: + log.error('From-Path: expected %r, got %r' % (expected_from, from_path)) return MSRPNoSuchSessionError('Invalid From-Path')