diff --git a/msrplib/protocol.py b/msrplib/protocol.py index f6d2038..9d84879 100644 --- a/msrplib/protocol.py +++ b/msrplib/protocol.py @@ -1,753 +1,809 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details import random import re from collections import deque, namedtuple from application.system import host as host_module from twisted.internet.protocol import connectionDone from twisted.protocols.basic import LineReceiver from msrplib import MSRPError class ParsingError(MSRPError): pass class HeaderParsingError(ParsingError): def __init__(self, header, reason): self.header = header ParsingError.__init__(self, 'Error parsing {} header: {}'.format(header, reason)) # Header value data types (for decoded values) # ByteRange = namedtuple('ByteRange', ['start', 'end', 'total']) Status = namedtuple('Status', ['code', 'comment']) ContentDisposition = namedtuple('ContentDisposition', ['disposition', 'parameters']) # Header value types (describe how to encode/decode the value) # class SimpleHeaderType(object): data_type = object @staticmethod def decode(encoded): return encoded @staticmethod def encode(decoded): return decoded class UTF8HeaderType(object): data_type = unicode @staticmethod def decode(encoded): return encoded.decode('utf-8') @staticmethod def encode(decoded): return decoded.encode('utf-8') class URIHeaderType(object): data_type = deque @staticmethod def decode(encoded): return deque(parse_uri(uri) for uri in encoded.split(' ')) @staticmethod def encode(decoded): return ' '.join(str(uri) for uri in decoded) class IntegerHeaderType(object): data_type = int @staticmethod def decode(encoded): return int(encoded) @staticmethod def encode(decoded): return str(decoded) class LimitedChoiceHeaderType(SimpleHeaderType): allowed_values = frozenset() @classmethod def decode(cls, encoded): if encoded not in cls.allowed_values: raise ValueError('Invalid value: {!r}'.format(encoded)) return encoded class SuccessReportHeaderType(LimitedChoiceHeaderType): allowed_values = frozenset({'yes', 'no'}) class FailureReportHeaderType(LimitedChoiceHeaderType): allowed_values = frozenset({'yes', 'no', 'partial'}) class ByteRangeHeaderType(object): data_type = ByteRange regex = re.compile(r'(?P\d+)-(?P\*|\d+)/(?P\*|\d+)') @classmethod def decode(cls, encoded): match = cls.regex.match(encoded) if match is None: raise ValueError('Invalid byte range value: {!r}'.format(encoded)) start, end, total = match.groups() start = int(start) end = int(end) if end != '*' else None total = int(total) if total != '*' else None return ByteRange(start, end, total) @staticmethod def encode(decoded): start, end, total = decoded return '{}-{}/{}'.format(start, end or '*', total or '*') class StatusHeaderType(object): data_type = Status @staticmethod def decode(encoded): namespace, sep, rest = encoded.partition(' ') if namespace != '000' or sep != ' ': raise ValueError('Invalid status value: {!r}'.format(encoded)) code, _, comment = rest.partition(' ') if not code.isdigit() or len(code) != 3: raise ValueError('Invalid status code: {!r}'.format(code)) return Status(int(code), comment or None) @staticmethod def encode(decoded): code, comment = decoded if comment is None: return '000 {:03d}'.format(code) else: return '000 {:03d} {}'.format(code, comment) class ContentDispositionHeaderType(object): data_type = ContentDisposition regex = re.compile(r'(\w+)=("[^"]+"|[^";]+)') @classmethod def decode(cls, encoded): disposition, _, parameters = encoded.partition(';') if not disposition: raise ValueError('Invalid content disposition: {!r}'.format(encoded)) return ContentDisposition(disposition, {name: value.strip('"') for name, value in cls.regex.findall(parameters)}) @staticmethod def encode(decoded): disposition, parameters = decoded return '; '.join([disposition] + ['{}="{}"'.format(name, value) for name, value in parameters.iteritems()]) class ParameterListHeaderType(object): data_type = dict regex = re.compile(r'(\w+)=("[^"]+"|[^",]+)') @classmethod def decode(cls, encoded): return {name: value.strip('"') for name, value in cls.regex.findall(encoded)} @staticmethod def encode(decoded): return ', '.join('{}="{}"'.format(name, value) for name, value in decoded.iteritems()) class DigestHeaderType(ParameterListHeaderType): @classmethod def decode(cls, encoded): algorithm, sep, parameters = encoded.partition(' ') if algorithm != 'Digest' or sep != ' ': raise ValueError('Invalid Digest header value') return super(DigestHeaderType, cls).decode(parameters) @staticmethod def encode(decoded): return 'Digest ' + super(DigestHeaderType, DigestHeaderType).encode(decoded) # Header classes # class MSRPHeaderMeta(type): __classmap__ = {} name = None def __init__(cls, name, bases, dictionary): type.__init__(cls, name, bases, dictionary) if cls.name is not None: cls.__classmap__[cls.name] = cls def __call__(cls, *args, **kw): if cls.name is not None: return super(MSRPHeaderMeta, cls).__call__(*args, **kw) # specialized class, instantiated directly. else: return cls._instantiate_specialized_class(*args, **kw) # non-specialized class, instantiated as a more specialized class if available. def _instantiate_specialized_class(cls, name, value): if name in cls.__classmap__: return super(MSRPHeaderMeta, cls.__classmap__[name]).__call__(value) else: return super(MSRPHeaderMeta, cls).__call__(name, value) class MSRPHeader(object): __metaclass__ = MSRPHeaderMeta name = None type = SimpleHeaderType def __init__(self, name, value): self.name = name if isinstance(value, str): self.encoded = value else: self.decoded = value def __eq__(self, other): if isinstance(other, MSRPHeader): return self.name == other.name and self.decoded == other.decoded return NotImplemented def __ne__(self, other): return not self == other @property def encoded(self): if self._encoded is None: self._encoded = self.type.encode(self._decoded) return self._encoded @encoded.setter def encoded(self, encoded): self._encoded = encoded self._decoded = None @property def decoded(self): if self._decoded is None: try: self._decoded = self.type.decode(self._encoded) except Exception as e: raise HeaderParsingError(self.name, str(e)) return self._decoded @decoded.setter def decoded(self, decoded): if not isinstance(decoded, self.type.data_type): try: # noinspection PyArgumentList decoded = self.type.data_type(decoded) except Exception: raise TypeError('value must be an instance of {}'.format(self.type.data_type.__name__)) self._decoded = decoded self._encoded = None class MSRPNamedHeader(MSRPHeader): def __init__(self, value): MSRPHeader.__init__(self, self.name, value) class ToPathHeader(MSRPNamedHeader): name = 'To-Path' type = URIHeaderType class FromPathHeader(MSRPNamedHeader): name = 'From-Path' type = URIHeaderType class MessageIDHeader(MSRPNamedHeader): name = 'Message-ID' type = SimpleHeaderType class SuccessReportHeader(MSRPNamedHeader): name = 'Success-Report' type = SuccessReportHeaderType class FailureReportHeader(MSRPNamedHeader): name = 'Failure-Report' type = FailureReportHeaderType class ByteRangeHeader(MSRPNamedHeader): name = 'Byte-Range' type = ByteRangeHeaderType @property def start(self): return self.decoded.start @property def end(self): return self.decoded.end @property def total(self): return self.decoded.total class StatusHeader(MSRPNamedHeader): name = 'Status' type = StatusHeaderType @property def code(self): return self.decoded.code @property def comment(self): return self.decoded.comment class ExpiresHeader(MSRPNamedHeader): name = 'Expires' type = IntegerHeaderType class MinExpiresHeader(MSRPNamedHeader): name = 'Min-Expires' type = IntegerHeaderType class MaxExpiresHeader(MSRPNamedHeader): name = 'Max-Expires' type = IntegerHeaderType class UsePathHeader(MSRPNamedHeader): name = 'Use-Path' type = URIHeaderType class WWWAuthenticateHeader(MSRPNamedHeader): name = 'WWW-Authenticate' type = DigestHeaderType class AuthorizationHeader(MSRPNamedHeader): name = 'Authorization' type = DigestHeaderType class AuthenticationInfoHeader(MSRPNamedHeader): name = 'Authentication-Info' type = ParameterListHeaderType class ContentTypeHeader(MSRPNamedHeader): name = 'Content-Type' type = SimpleHeaderType class ContentIDHeader(MSRPNamedHeader): name = 'Content-ID' type = SimpleHeaderType class ContentDescriptionHeader(MSRPNamedHeader): name = 'Content-Description' type = SimpleHeaderType class ContentDispositionHeader(MSRPNamedHeader): name = 'Content-Disposition' type = ContentDispositionHeaderType class UseNicknameHeader(MSRPNamedHeader): name = 'Use-Nickname' type = UTF8HeaderType class HeaderOrderMapping(dict): __levels__ = { 0: ['To-Path'], 1: ['From-Path'], 2: ['Status', 'Message-ID', 'Byte-Range', 'Success-Report', 'Failure-Report'] + ['Authorization', 'Authentication-Info', 'WWW-Authenticate', 'Expires', 'Min-Expires', 'Max-Expires', 'Use-Path', 'Use-Nickname'], 3: ['Content-ID', 'Content-Description', 'Content-Disposition'], 4: ['Content-Type'] } def __init__(self): super(HeaderOrderMapping, self).__init__({name: level for level, name_list in self.__levels__.items() for name in name_list}) def __missing__(self, key): return 3 if key.startswith('Content-') else 2 sort_key = dict.__getitem__ class HeaderOrdering(object): name_map = HeaderOrderMapping() sort_key = name_map.sort_key class MissingHeader(object): decoded = None +class HeaderMapping(dict): + def __init__(self, *args, **kw): + super(HeaderMapping, self).__init__(*args, **kw) + self.__modified__ = True + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, super(HeaderMapping, self).__repr__()) + + def __setitem__(self, key, value): + super(HeaderMapping, self).__setitem__(key, value) + self.__modified__ = True + + def __delitem__(self, key): + super(HeaderMapping, self).__delitem__(key) + self.__modified__ = True + + def __copy__(self): + return self.__class__(self) + + def clear(self): + super(HeaderMapping, self).clear() + self.__modified__ = True + + def copy(self): + return self.__class__(self) + + def pop(self, *args): + result = super(HeaderMapping, self).pop(*args) + self.__modified__ = True + return result + + def popitem(self): + result = super(HeaderMapping, self).popitem() + self.__modified__ = True + return result + + def setdefault(self, *args): + result = super(HeaderMapping, self).setdefault(*args) + self.__modified__ = True + return result + + def update(self, *args, **kw): + super(HeaderMapping, self).update(*args, **kw) + self.__modified__ = True + + class MSRPData(object): + __immutable__ = frozenset({'method', 'code', 'comment', 'headers'}) # Immutable attributes (cannot be overwritten) + def __init__(self, transaction_id, method=None, code=None, comment=None, headers=None, data='', contflag='$'): if method is None and code is None: raise ValueError('either method or code must be specified') elif method is not None and code is not None: raise ValueError('method and code cannot be both specified') elif code is None and comment is not None: raise ValueError('comment should only be specified when code is specified') self.transaction_id = transaction_id self.method = method self.code = code self.comment = comment - self.headers = headers or {} + self.headers = HeaderMapping(headers or {}) self.data = data self.contflag = contflag if method is not None: self.first_line = 'MSRP {} {}'.format(transaction_id, method) elif comment is None: self.first_line = 'MSRP {} {:03d}'.format(transaction_id, code) else: self.first_line = 'MSRP {} {:03d} {}'.format(transaction_id, code, comment) + self.__modified__ = True def __setattr__(self, name, value): - if name in {'method', 'code', 'comment'} and name in self.__dict__: - raise AttributeError('Cannot overwrite attribute') - if name == 'transaction_id' and name in self.__dict__: - self.first_line = self.first_line.replace(self.transaction_id, value) + if name in self.__dict__: + if name in self.__immutable__: + raise AttributeError('Cannot overwrite attribute') + elif name == 'transaction_id': + self.first_line = self.first_line.replace(self.transaction_id, value) + self.__modified__ = True super(MSRPData, self).__setattr__(name, value) def __str__(self): # TODO: make __str__ == encode()? return self.first_line def __repr__(self): description = self.first_line for name in sorted(self.headers, key=HeaderOrdering.sort_key): description += ' {}={!r}'.format(name, self.headers[name].encoded) description += ' len={}'.format(self.size) return '<{} at {:#x} {} {}>'.format(self.__class__.__name__, id(self), description, self.contflag) def __eq__(self, other): if isinstance(other, MSRPData): return self.first_line == other.first_line and self.headers == other.headers and self.data == other.data and self.contflag == other.contflag return NotImplemented def __ne__(self, other): return not self == other def copy(self): return self.__class__(self.transaction_id, self.method, self.code, self.comment, self.headers.copy(), self.data, self.contflag) def add_header(self, header): self.headers[header.name] = header def verify_headers(self): if 'To-Path' not in self.headers: raise HeaderParsingError('To-Path', 'header is missing') if 'From-Path' not in self.headers: raise HeaderParsingError('From-Path', 'header is missing') for header in self.headers.itervalues(): _ = header.decoded @property def from_path(self): return self.headers.get('From-Path', MissingHeader).decoded @property def to_path(self): return self.headers.get('To-Path', MissingHeader).decoded @property def content_type(self): return self.headers.get('Content-Type', MissingHeader).decoded @property def message_id(self): return self.headers.get('Message-ID', MissingHeader).decoded @property def byte_range(self): return self.headers.get('Byte-Range', MissingHeader).decoded @property def status(self): return self.headers.get('Status', MissingHeader).decoded @property def failure_report(self): return self.headers.get('Failure-Report', MissingHeader).decoded or 'yes' @property def success_report(self): return self.headers.get('Success-Report', MissingHeader).decoded or 'no' @property def size(self): return len(self.data) - def encode_start(self): - lines = [self.first_line] + ['{}: {}'.format(name, self.headers[name].encoded) for name in sorted(self.headers, key=HeaderOrdering.sort_key)] - if 'Content-Type' in self.headers: - lines.append('\r\n') - return '\r\n'.join(lines) + @property + def encoded_header(self): + if self.__modified__ or self.headers.__modified__: + lines = [self.first_line] + ['{}: {}'.format(name, self.headers[name].encoded) for name in sorted(self.headers, key=HeaderOrdering.sort_key)] + if 'Content-Type' in self.headers: + lines.append('\r\n') + self.__dict__['encoded_header'] = '\r\n'.join(lines) + self.__modified__ = self.headers.__modified__ = False + return self.__dict__['encoded_header'] - def encode_end(self, continuation): - return '\r\n-------%s%s\r\n' % (self.transaction_id, continuation) + @property + def encoded_footer(self): + return '\r\n-------{}{}\r\n'.format(self.transaction_id, self.contflag) def encode(self): - return self.encode_start() + self.data + self.encode_end(self.contflag) + return self.encoded_header + self.data + self.encoded_footer class MSRPProtocol(LineReceiver): MAX_LENGTH = 16384 MAX_LINES = 64 def __init__(self, msrp_transport): self.msrp_transport = msrp_transport self.term_buf = '' self.term_re = None self.term_substrings = [] self._reset() def _reset(self): self._chunk_header = '' self.data = None self.line_count = 0 def connectionMade(self): self.msrp_transport._got_transport(self.transport) def lineReceived(self, line): if self.data: if len(line) == 0: terminator = '\r\n-------' + self.data.transaction_id continue_flags = [c+'\r\n' for c in '$#+'] self.term_buf = '' self.term_re = re.compile("^(.*)%s([$#+])\r\n(.*)$" % re.escape(terminator), re.DOTALL) self.term_substrings = [terminator[:i] for i in xrange(1, len(terminator)+1)] + [terminator+cont[:i] for cont in continue_flags for i in xrange(1, len(cont))] self.term_substrings.reverse() self.msrp_transport.logger.received_new_chunk(self._chunk_header+self.delimiter, self.msrp_transport, chunk=self.data) self.msrp_transport._data_start(self.data) self.setRawMode() else: match = self.term_re.match(line) if match: continuation = match.group(1) self.msrp_transport.logger.received_new_chunk(self._chunk_header, self.msrp_transport, chunk=self.data) self.msrp_transport.logger.received_chunk_end(line+self.delimiter, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_start(self.data) self.msrp_transport._data_end(continuation) self._reset() # This line is only here because it allows the subclass MSRPPRotocol_withLogging # to know that the packet ended. In need of redesign. -Luci self.setLineMode('') else: self._chunk_header += line+self.delimiter self.line_count += 1 if self.line_count > self.MAX_LINES: self.msrp_transport.logger.received_illegal_data(self._chunk_header, self.msrp_transport) self._reset() return try: name, value = line.split(': ', 1) except ValueError: return # let this pass silently, we'll just not read this line else: self.data.add_header(MSRPHeader(name, value)) else: # we received a new message try: msrp, transaction_id, rest = line.split(" ", 2) except ValueError: self.msrp_transport.logger.received_illegal_data(line+self.delimiter, self.msrp_transport) return # drop connection? if msrp != "MSRP": self.msrp_transport.logger.received_illegal_data(line+self.delimiter, self.msrp_transport) return # drop connection? method, code, comment = None, None, None rest_sp = rest.split(" ", 1) try: if len(rest_sp[0]) != 3: raise ValueError code = int(rest_sp[0]) except ValueError: # we have a request method = rest_sp[0] else: # we have a response if len(rest_sp) > 1: comment = rest_sp[1] self.data = MSRPData(transaction_id, method, code, comment) self.term_re = re.compile("^-------%s([$#+])$" % re.escape(transaction_id)) self._chunk_header = line+self.delimiter def lineLengthExceeded(self, line): self._reset() def rawDataReceived(self, data): data = self.term_buf + data match = self.term_re.match(data) if match: # we got the last data for this message contents, continuation, extra = match.groups() if contents: self.msrp_transport.logger.received_chunk_data(contents, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_write(contents, final=True) self.msrp_transport.logger.received_chunk_end('\r\n-------%s%s\r\n' % (self.data.transaction_id, continuation), self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_end(continuation) self._reset() self.setLineMode(extra) else: for term in self.term_substrings: if data.endswith(term): self.term_buf = data[-len(term):] data = data[:-len(term)] break else: self.term_buf = '' self.msrp_transport.logger.received_chunk_data(data, self.msrp_transport, transaction_id=self.data.transaction_id) self.msrp_transport._data_write(data, final=False) def connectionLost(self, reason=connectionDone): self.msrp_transport._connectionLost(reason) _re_uri = re.compile("^(?P.*?)://(((?P.*?)@)?(?P.*?)(:(?P[0-9]+?))?)(/(?P.*?))?;(?P.*?)(;(?P.*))?$") def parse_uri(uri_str): match = _re_uri.match(uri_str) if match is None: raise ParsingError("Cannot parse URI") uri_params = match.groupdict() if uri_params["port"] is not None: uri_params["port"] = int(uri_params["port"]) if uri_params["parameters"] is not None: try: uri_params["parameters"] = dict(param.split("=") for param in uri_params["parameters"].split(";")) except ValueError: raise ParsingError("Cannot parse URI parameters") scheme = uri_params.pop("scheme") if scheme == "msrp": uri_params["use_tls"] = False elif scheme == "msrps": uri_params["use_tls"] = True else: raise ParsingError("Invalid scheme user in URI: %s" % scheme) if uri_params["transport"] != "tcp": raise ParsingError('Invalid transport in URI, only "tcp" is accepted: %s' % uri_params["transport"]) return URI(**uri_params) class ConnectInfo(object): host = None use_tls = True port = 2855 def __init__(self, host=None, use_tls=None, port=None, credentials=None): if host is not None: self.host = host if use_tls is not None: self.use_tls = use_tls if port is not None: self.port = port self.credentials = credentials if self.use_tls and self.credentials is None: from gnutls.interfaces.twisted import X509Credentials self.credentials = X509Credentials(None, None) @property def scheme(self): if self.use_tls: return 'msrps' else: return 'msrp' # use TLS_URI and TCP_URI ? class URI(ConnectInfo): def __init__(self, host=None, use_tls=None, user=None, port=None, session_id=None, transport="tcp", parameters=None, credentials=None): ConnectInfo.__init__(self, host or host_module.default_ip, use_tls=use_tls, port=port, credentials=credentials) self.user = user if session_id is None: session_id = '%x' % random.getrandbits(80) self.session_id = session_id self.transport = transport if parameters is None: self.parameters = {} else: self.parameters = parameters def __repr__(self): params = [self.host, self.use_tls, self.user, self.port, self.session_id, self.transport, self.parameters] defaults = [False, None, None, None, 'tcp', {}] while defaults and params[-1]==defaults[-1]: del params[-1] del defaults[-1] return '%s(%s)' % (self.__class__.__name__, ', '.join(`x` for x in params)) def __str__(self): uri_str = [] if self.use_tls: uri_str.append("msrps://") else: uri_str.append("msrp://") if self.user: uri_str.extend([self.user, "@"]) uri_str.append(self.host) if self.port: uri_str.extend([":", str(self.port)]) if self.session_id: uri_str.extend(["/", self.session_id]) uri_str.extend([";", self.transport]) for key, value in self.parameters.iteritems(): uri_str.extend([";", key, "=", value]) return "".join(uri_str) def __eq__(self, other): """MSRP URI comparison according to section 6.1 of RFC 4975""" if self is other: return True try: if self.use_tls != other.use_tls: return False if self.host.lower() != other.host.lower(): return False if self.port != other.port: return False if self.session_id != other.session_id: return False if self.transport.lower() != other.transport.lower(): return False except AttributeError: return False return True def __ne__(self, other): return not self == other def __hash__(self): return hash((self.use_tls, self.host.lower(), self.port, self.session_id, self.transport.lower())) diff --git a/msrplib/transport.py b/msrplib/transport.py index 7fb33f7..a3a0fa2 100644 --- a/msrplib/transport.py +++ b/msrplib/transport.py @@ -1,336 +1,336 @@ # Copyright (C) 2008-2012 AG Projects. See LICENSE for details import random from application import log from twisted.internet.error import ConnectionDone from eventlib.twistedutil.protocol import GreenTransportBase from msrplib import protocol, MSRPError from msrplib.trafficlog import Logger log = log.get_logger('msrplib') class ChunkParseError(MSRPError): """Failed to parse incoming chunk""" class MSRPTransactionError(MSRPError): def __init__(self, comment=None, code=None): if comment is not None: self.comment = comment if code is not None: self.code = code if not hasattr(self, 'code'): raise TypeError("must provide 'code'") def __str__(self): if hasattr(self, 'comment'): return '%s %s' % (self.code, self.comment) else: return str(self.code) class MSRPBadRequest(MSRPTransactionError): code = 400 comment = 'Bad Request' def __str__(self): return 'Remote party sent bogus data' class MSRPNoSuchSessionError(MSRPTransactionError): code = 481 comment = 'No such session' data_start, data_end, data_write, data_final_write = xrange(4) class MSRPProtocol_withLogging(protocol.MSRPProtocol): _new_chunk = True def rawDataReceived(self, data): self.msrp_transport.logger.report_in(data, self.msrp_transport, self._new_chunk) protocol.MSRPProtocol.rawDataReceived(self, data) def lineReceived(self, line): self.msrp_transport.logger.report_in(line+self.delimiter, self.msrp_transport, self._new_chunk) self._new_chunk = False protocol.MSRPProtocol.lineReceived(self, line) def connectionLost(self, reason): msg = 'Closed connection to %s:%s' % (self.transport.getPeer().host, self.transport.getPeer().port) if not isinstance(reason.value, ConnectionDone): msg += ' (%s)' % reason.getErrorMessage() self.msrp_transport.logger.info(msg) protocol.MSRPProtocol.connectionLost(self, reason) def setLineMode(self, extra): self._new_chunk = True self.msrp_transport.logger.report_in('', self.msrp_transport, packet_done=True) return protocol.MSRPProtocol.setLineMode(self, extra) def make_report(chunk, code, comment): if chunk.success_report == 'yes' or (chunk.failure_report in ('yes', 'partial') and code != 200): report = protocol.MSRPData(transaction_id='%x' % random.getrandbits(64), method='REPORT') report.add_header(protocol.ToPathHeader(chunk.from_path)) report.add_header(protocol.FromPathHeader([chunk.to_path[0]])) report.add_header(protocol.StatusHeader(protocol.Status(code, comment))) report.add_header(protocol.MessageIDHeader(chunk.message_id)) if chunk.byte_range is None: start = 1 total = chunk.size else: start, end, total = chunk.byte_range report.add_header(protocol.ByteRangeHeader(protocol.ByteRange(start, start+chunk.size-1, total))) return report else: return None def make_response(chunk, code, comment): """Construct a response to a request as described in RFC4975 Section 7.2. If the response is not needed, return None. If a required header missing, raise ChunkParseError. """ if chunk.failure_report == 'no': return if chunk.failure_report == 'partial' and code == 200: return if chunk.to_path is None: raise ChunkParseError('missing To-Path header: %r' % chunk) if chunk.from_path is None: raise ChunkParseError('missing From-Path header: %r' % chunk) if chunk.method == 'SEND': to_path = [chunk.from_path[0]] else: to_path = chunk.from_path from_path = [chunk.to_path[0]] response = protocol.MSRPData(chunk.transaction_id, code=code, comment=comment) response.add_header(protocol.ToPathHeader(to_path)) response.add_header(protocol.FromPathHeader(from_path)) return response class MSRPTransport(GreenTransportBase): protocol_class = MSRPProtocol_withLogging def __init__(self, local_uri, logger, use_sessmatch=False): GreenTransportBase.__init__(self) if local_uri is not None and not isinstance(local_uri, protocol.URI): raise TypeError('Not MSRP URI instance: %r' % (local_uri, )) # The following members define To-Path and From-Path headers as following: # * Outgoing request: # From-Path: local_uri # To-Path: local_path + remote_path + [remote_uri] # * Incoming request: # From-Path: remote_path + remote_uri # To-Path: remote_path + local_path + [local_uri] # XXX self.local_uri = local_uri if logger is None: logger = Logger() self.logger = logger self.local_path = [] self.remote_uri = None self.remote_path = [] self.use_sessmatch = use_sessmatch def next_host(self): if self.local_path: return self.local_path[0] return self.full_remote_path[0] def set_local_path(self, lst): self.local_path = lst @property def full_local_path(self): "suitable to put into INVITE" return self.local_path + [self.local_uri] @property def full_remote_path(self): return self.remote_path + [self.remote_uri] def make_request(self, method): transaction_id = '%x' % random.getrandbits(64) chunk = protocol.MSRPData(transaction_id=transaction_id, method=method) chunk.add_header(protocol.ToPathHeader(self.local_path + self.remote_path + [self.remote_uri])) chunk.add_header(protocol.FromPathHeader([self.local_uri])) return chunk def make_send_request(self, message_id=None, data='', start=1, end=None, length=None): chunk = self.make_request('SEND') if end is None: end = start - 1 + len(data) if length is None: length = start - 1 + len(data) if end == length != '*': contflag = '$' else: contflag = '+' chunk.add_header(protocol.ByteRangeHeader(protocol.ByteRange(start, end if length <= 2048 else None, length))) if message_id is None: message_id = '%x' % random.getrandbits(64) chunk.add_header(protocol.MessageIDHeader(message_id)) chunk.data = data chunk.contflag = contflag return chunk def _data_start(self, data): self._queue.send((data_start, data)) def _data_end(self, continuation): self._queue.send((data_end, continuation)) def _data_write(self, contents, final): if final: self._queue.send((data_final_write, contents)) else: self._queue.send((data_write, contents)) def write(self, bytes, wait=True): """Write `bytes' to the socket. If `wait' is true, wait for an operation to complete""" self.logger.report_out(bytes, self.transport) return GreenTransportBase.write(self, bytes, wait) def write_chunk(self, chunk, wait=True): - trailer = chunk.encode_start() - footer = chunk.encode_end(chunk.contflag) - self.write(trailer+chunk.data+footer, wait=wait) - self.logger.sent_new_chunk(trailer, self, chunk=chunk) + header = chunk.encoded_header + footer = chunk.encoded_footer + self.write(header+chunk.data+footer, wait=wait) + self.logger.sent_new_chunk(header, self, chunk=chunk) if chunk.data: self.logger.sent_chunk_data(chunk.data, self, chunk.transaction_id) self.logger.sent_chunk_end(footer, self, chunk.transaction_id) def read_chunk(self, max_size=1024*1024*4): """Wait for a new chunk and return it. If there was an error, close the connection and raise ChunkParseError. In case of unintelligible input, lose the connection and return None. When the connection is closed, raise the reason of the closure (e.g. ConnectionDone). """ assert max_size > 0 func, msrpdata = self._wait() if func!=data_start: self.logger.debug('Bad data: %r %r' % (func, msrpdata)) self.loseConnection() raise ChunkParseError data = msrpdata.data func, param = self._wait() while func == data_write: data += param if len(data) > max_size: self.logger.debug('Chunk is too big (max_size=%d bytes)' % max_size) self.loseConnection() raise ChunkParseError func, param = self._wait() if func == data_final_write: data += param func, param = self._wait() if func != data_end: self.logger.debug('Bad data: %r %s' % (func, repr(param)[:100])) self.loseConnection() raise ChunkParseError if param not in "$+#": self.logger.debug('Bad data: %r %s' % (func, repr(param)[:100])) self.loseConnection() raise ChunkParseError msrpdata.data = data msrpdata.contflag = param self.logger.debug('read_chunk -> %r' % (msrpdata, )) return msrpdata def _set_full_remote_path(self, full_remote_path): "as received in response to INVITE" if not all(isinstance(x, protocol.URI) for x in full_remote_path): raise TypeError('Not all elements are MSRP URI: %r' % full_remote_path) self.remote_uri = full_remote_path[-1] self.remote_path = full_remote_path[:-1] def bind(self, full_remote_path): self._set_full_remote_path(full_remote_path) chunk = self.make_send_request() self.write_chunk(chunk) # With some ACM implementations both parties may think they are active, # so they will both send an empty SEND request. -Saul while True: chunk = self.read_chunk() if chunk.code is None: # This was not a response, it was a request if chunk.method == 'SEND' and not chunk.data: self.write_response(chunk, 200, 'OK') else: self.loseConnection(wait=False) raise MSRPNoSuchSessionError('Chunk received while binding session: %s' % chunk) elif chunk.code != 200: self.loseConnection(wait=False) raise MSRPNoSuchSessionError('Cannot bind session: %s' % chunk) else: break def write_response(self, chunk, code, comment, wait=True): """Generate and write the response, lose the connection in case of error""" try: response = make_response(chunk, code, comment) except ChunkParseError, ex: log.error('Failed to generate a response: %s' % ex) self.loseConnection(wait=False) raise except Exception: log.exception('Failed to generate a response') self.loseConnection(wait=False) raise else: if response is not None: self.write_chunk(response, wait=wait) def accept_binding(self, full_remote_path): self._set_full_remote_path(full_remote_path) chunk = self.read_chunk() error = self.check_incoming_SEND_chunk(chunk) if error is None: code, comment = 200, 'OK' else: code, comment = error.code, error.comment self.write_response(chunk, code, comment) if 'Content-Type' in chunk.headers or chunk.size>0: # deliver chunk to read_chunk data = chunk.data chunk.data = '' self._data_start(chunk) self._data_write(data, final=True) self._data_end(chunk.contflag) def check_incoming_SEND_chunk(self, chunk): """Check the 'To-Path' and 'From-Path' of the incoming SEND chunk. Return None is the paths are valid for this connection. If an error is detected and MSRPError is created and returned. """ assert chunk.method == 'SEND', repr(chunk) if chunk.to_path is None: return MSRPBadRequest('To-Path header missing') if chunk.from_path is None: return MSRPBadRequest('From-Path header missing') to_path = list(chunk.to_path) from_path = list(chunk.from_path) expected_to = [self.local_uri] expected_from = self.local_path + self.remote_path + [self.remote_uri] # Match only session ID when use_sessmatch is set (http://tools.ietf.org/html/draft-ietf-simple-msrp-sessmatch-10) if self.use_sessmatch: if to_path[0].session_id != expected_to[0].session_id: log.error('To-Path: expected session_id %s, got %s' % (expected_to[0].session_id, to_path[0].session_id)) return MSRPNoSuchSessionError('Invalid To-Path') if from_path[0].session_id != expected_from[0].session_id: log.error('From-Path: expected session_id %s, got %s' % (expected_from[0].session_id, from_path[0].session_id)) return MSRPNoSuchSessionError('Invalid From-Path') else: if to_path != expected_to: log.error('To-Path: expected %r, got %r' % (expected_to, to_path)) return MSRPNoSuchSessionError('Invalid To-Path') if from_path != expected_from: log.error('From-Path: expected %r, got %r' % (expected_from, from_path)) return MSRPNoSuchSessionError('Invalid From-Path')