diff --git a/sipsimple/payloads/__init__.py b/sipsimple/payloads/__init__.py
index 6851667b..8039a111 100644
--- a/sipsimple/payloads/__init__.py
+++ b/sipsimple/payloads/__init__.py
@@ -1,1203 +1,1204 @@
__all__ = ['ParserError',
'BuilderError',
'ValidationError',
'IterateTypes',
'IterateIDs',
'IterateItems',
'All',
'parse_qname',
'XMLDocument',
'XMLAttribute',
'XMLElementID',
'XMLElementChild',
'XMLElementChoiceChild',
'XMLStringChoiceChild',
'XMLElement',
'XMLRootElement',
'XMLSimpleElement',
'XMLStringElement',
'XMLLocalizedStringElement',
'XMLBooleanElement',
'XMLByteElement',
'XMLUnsignedByteElement',
'XMLShortElement',
'XMLUnsignedShortElement',
'XMLIntElement',
'XMLUnsignedIntElement',
'XMLLongElement',
'XMLUnsignedLongElement',
'XMLIntegerElement',
'XMLPositiveIntegerElement',
'XMLNegativeIntegerElement',
'XMLNonNegativeIntegerElement',
'XMLNonPositiveIntegerElement',
'XMLDecimalElement',
'XMLDateTimeElement',
'XMLAnyURIElement',
'XMLEmptyElement',
'XMLEmptyElementRegistryType',
'XMLListElement',
'XMLListRootElement',
'XMLStringListElement']
import os
import sys
import urllib.request, urllib.parse, urllib.error
from collections import defaultdict, deque
from copy import deepcopy
from decimal import Decimal
from itertools import chain
from weakref import WeakValueDictionary
from application.python import Null
from application.python.descriptor import classproperty
from application.python.types import MarkerType
from application.python.weakref import weakobjectmap
from lxml import etree
from sipsimple.payloads.datatypes import Boolean, Byte, UnsignedByte, Short, UnsignedShort, Int, UnsignedInt, Long, UnsignedLong
from sipsimple.payloads.datatypes import PositiveInteger, NegativeInteger, NonNegativeInteger, NonPositiveInteger, DateTime, AnyURI
from sipsimple.util import All
## Exceptions
class ParserError(Exception): pass
class BuilderError(Exception): pass
class ValidationError(ParserError): pass
## Markers
class IterateTypes(metaclass=MarkerType): pass
class IterateIDs(metaclass=MarkerType): pass
class IterateItems(metaclass=MarkerType): pass
class StoredAttribute(metaclass=MarkerType): pass
## Utilities
def parse_qname(qname):
if qname[0] == '{':
qname = qname[1:]
return qname.split('}')
else:
return None, qname
## XMLDocument
class XMLDocumentType(type):
def __init__(cls, name, bases, dct):
cls.nsmap = {}
cls.schema_map = {}
cls.element_map = {}
cls.root_element = None
cls.schema = None
cls.parser = None
for base in reversed(bases):
if hasattr(base, 'element_map'):
cls.element_map.update(base.element_map)
if hasattr(base, 'schema_map'):
cls.schema_map.update(base.schema_map)
if hasattr(base, 'nsmap'):
cls.nsmap.update(base.nsmap)
cls._update_schema()
def __setattr__(cls, name, value):
if name == 'schema_path':
if cls is not XMLDocument:
raise AttributeError("%s can only be changed on XMLDocument" % name)
super(XMLDocumentType, cls).__setattr__(name, value)
def update_schema(document_class):
document_class._update_schema()
for document_subclass in document_class.__subclasses__():
update_schema(document_subclass)
update_schema(XMLDocument)
else:
super(XMLDocumentType, cls).__setattr__(name, value)
def _update_schema(cls):
if cls.schema_map:
location_map = {ns: urllib.parse.quote(os.path.abspath(os.path.join(cls.schema_path, schema_file)).replace('\\', '//')) for ns, schema_file in list(cls.schema_map.items())}
schema = """
%s
""" % '\r\n'.join('' % (namespace, schema_location) for namespace, schema_location in list(location_map.items()))
cls.schema = etree.XMLSchema(etree.XML(schema))
cls.parser = etree.XMLParser(schema=cls.schema, remove_blank_text=True)
else:
cls.schema = None
cls.parser = etree.XMLParser(remove_blank_text=True)
class XMLDocument(object, metaclass=XMLDocumentType):
encoding = 'UTF-8'
content_type = None
schema_path = os.path.join(os.path.dirname(__file__), 'xml-schemas')
@classmethod
def parse(cls, document):
try:
- if isinstance(document, str):
- xml = etree.XML(document, parser=cls.parser)
- elif isinstance(document, str):
- xml = etree.XML(document.encode('utf-8'), parser=cls.parser)
+# if isinstance(document, str):
+# xml = etree.XML(document, parser=cls.parser)
+# elif isinstance(document, bytes):
+ if isinstance(document, bytes):
+ xml = etree.XML(document.decode('utf-8'), parser=cls.parser)
else:
xml = etree.parse(document, parser=cls.parser).getroot()
if cls.schema is not None:
cls.schema.assertValid(xml)
return cls.root_element.from_element(xml, xml_document=cls)
except (etree.DocumentInvalid, etree.XMLSyntaxError, ValueError) as e:
raise ParserError(str(e))
@classmethod
def build(cls, root_element, encoding=None, pretty_print=False, validate=True):
if type(root_element) is not cls.root_element:
raise TypeError("can only build XML documents from root elements of type %s" % cls.root_element.__name__)
element = root_element.to_element()
if validate and cls.schema is not None:
cls.schema.assertValid(element)
# Cleanup namespaces and move element NS mappings to the global scope.
normalized_element = etree.Element(element.tag, attrib=element.attrib, nsmap=dict(chain(iter(list(element.nsmap.items())), iter(list(cls.nsmap.items())))))
normalized_element.text = element.text
normalized_element.tail = element.tail
normalized_element.extend(deepcopy(child) for child in element)
etree.cleanup_namespaces(normalized_element)
return etree.tostring(normalized_element, encoding=encoding or cls.encoding, method='xml', xml_declaration=True, pretty_print=pretty_print)
@classmethod
def create(cls, build_kw={}, **kw):
return cls.build(cls.root_element(**kw), **build_kw)
@classmethod
def register_element(cls, xml_class):
cls.element_map[xml_class.qname] = xml_class
for child in cls.__subclasses__():
child.register_element(xml_class)
@classmethod
def get_element(cls, qname, default=None):
return cls.element_map.get(qname, default)
@classmethod
def register_namespace(cls, namespace, prefix=None, schema=None):
if prefix in cls.nsmap:
raise ValueError("prefix %s is already registered in %s" % (prefix, cls.__name__))
if namespace in iter(list(cls.nsmap.values())):
raise ValueError("namespace %s is already registered in %s" % (namespace, cls.__name__))
cls.nsmap[prefix] = namespace
if schema is not None:
cls.schema_map[namespace] = schema
cls._update_schema()
for child in cls.__subclasses__():
child.register_namespace(namespace, prefix, schema)
@classmethod
def unregister_namespace(cls, namespace):
try:
prefix = next((prefix for prefix in cls.nsmap if cls.nsmap[prefix]==namespace))
except StopIteration:
raise KeyError("namespace %s is not registered in %s" % (namespace, cls.__name__))
del cls.nsmap[prefix]
schema = cls.schema_map.pop(namespace, None)
if schema is not None:
cls._update_schema()
for child in cls.__subclasses__():
try:
child.unregister_namespace(namespace)
except KeyError:
pass
## Children descriptors
class XMLAttribute(object):
def __init__(self, name, xmlname=None, type=str, default=None, required=False, test_equal=True, onset=None, ondel=None):
self.name = name
self.xmlname = xmlname or name
self.type = type
self.default = default
self.__xmlparse__ = getattr(type, '__xmlparse__', lambda value: value)
self.__xmlbuild__ = getattr(type, '__xmlbuild__', str)
self.required = required
self.test_equal = test_equal
self.onset = onset
self.ondel = ondel
self.values = weakobjectmap()
def __get__(self, obj, objtype):
if obj is None:
return self
try:
return self.values[obj]
except KeyError:
value = self.values.setdefault(obj, self.default)
if value is not None:
obj.element.set(self.xmlname, self.build(value))
return value
def __set__(self, obj, value):
if value is not None and not isinstance(value, self.type):
value = self.type(value)
old_value = self.values.get(obj, self.default)
if value == old_value:
return
if value is not None:
obj.element.set(self.xmlname, self.build(value))
else:
obj.element.attrib.pop(self.xmlname, None)
self.values[obj] = value
obj.__dirty__ = True
if self.onset:
self.onset(obj, self, value)
def __delete__(self, obj):
obj.element.attrib.pop(self.xmlname, None)
try:
value = self.values.pop(obj)
except KeyError:
pass
else:
if value != self.default:
obj.__dirty__ = True
if self.ondel:
self.ondel(obj, self)
def parse(self, xmlvalue):
return self.__xmlparse__(xmlvalue)
def build(self, value):
return self.__xmlbuild__(value)
class XMLElementID(XMLAttribute):
"""An XMLAttribute that represents the ID of an element (immutable)."""
def __set__(self, obj, value):
if obj in self.values:
raise AttributeError("An XML element ID cannot be changed")
super(XMLElementID, self).__set__(obj, value)
def __delete__(self, obj):
raise AttributeError("An XML element ID cannot be deleted")
class XMLElementChild(object):
def __init__(self, name, type, required=False, test_equal=True, onset=None, ondel=None):
self.name = name
self.type = type
self.required = required
self.test_equal = test_equal
self.onset = onset
self.ondel = ondel
self.values = weakobjectmap()
def __get__(self, obj, objtype):
if obj is None:
return self
try:
return self.values[obj]
except KeyError:
return None
def __set__(self, obj, value):
if value is not None and not isinstance(value, self.type):
value = self.type(value)
same_value = False
old_value = self.values.get(obj)
if value is old_value:
return
elif value is not None and value == old_value:
value.__dirty__ = old_value.__dirty__
same_value = True
if old_value is not None:
obj.element.remove(old_value.element)
if value is not None:
obj._insert_element(value.element)
self.values[obj] = value
if not same_value:
obj.__dirty__ = True
if self.onset:
self.onset(obj, self, value)
def __delete__(self, obj):
try:
old_value = self.values.pop(obj)
except KeyError:
pass
else:
if old_value is not None:
obj.element.remove(old_value.element)
obj.__dirty__ = True
if self.ondel:
self.ondel(obj, self)
class XMLElementChoiceChildWrapper(object):
__slots__ = ('descriptor', 'type')
def __init__(self, descriptor, type):
self.descriptor = descriptor
self.type = type
def __getattribute__(self, name):
if name in ('descriptor', 'type', 'register_extension', 'unregister_extension'):
return super(XMLElementChoiceChildWrapper, self).__getattribute__(name)
else:
return self.descriptor.__getattribute__(name)
def __setattr__(self, name, value):
if name in ('descriptor', 'type'):
super(XMLElementChoiceChildWrapper, self).__setattr__(name, value)
else:
setattr(self.descriptor, name, value)
def __dir__(self):
return dir(self.descriptor) + ['descriptor', 'type', 'register_extension', 'unregister_extension']
def register_extension(self, type):
if self.extension_type is None:
raise ValueError("The %s XML choice element of %s does not support extensions" % (self.name, self.type.__name__))
if not issubclass(type, XMLElement) or not issubclass(type, self.extension_type):
raise TypeError("type is not a subclass of XMLElement and/or %s: %s" % (self.extension_type.__name__, type.__name__))
if type in self.types:
raise ValueError("%s is already registered as a choice extension" % type.__name__)
self.types.add(type)
self.type._xml_children_qname_map[type.qname] = (self.descriptor, type)
for child_class in self.type.__subclasses__():
child_class._xml_children_qname_map[type.qname] = (self.descriptor, type)
def unregister_extension(self, type):
if self.extension_type is None:
raise ValueError("The %s XML choice element of %s does not support extensions" % (self.name, self.type.__name__))
try:
self.types.remove(type)
except ValueError:
raise ValueError("%s is not a registered choice extension on %s" % (type.__name__, self.type.__name__))
del self.type._xml_children_qname_map[type.qname]
for child_class in self.type.__subclasses__():
del child_class._xml_children_qname_map[type.qname]
class XMLElementChoiceChild(object):
def __init__(self, name, types, extension_type=None, required=False, test_equal=True, onset=None, ondel=None):
self.name = name
self.types = set(types)
self.extension_type = extension_type
self.required = required
self.test_equal = test_equal
self.onset = onset
self.ondel = ondel
self.values = weakobjectmap()
def __get__(self, obj, objtype):
if obj is None:
return XMLElementChoiceChildWrapper(self, objtype)
try:
return self.values[obj]
except KeyError:
return None
def __set__(self, obj, value):
if value is not None and type(value) not in self.types:
raise TypeError("%s is not an acceptable type for %s" % (value.__class__.__name__, obj.__class__.__name__))
same_value = False
old_value = self.values.get(obj)
if value is old_value:
return
elif value is not None and value == old_value:
value.__dirty__ = old_value.__dirty__
same_value = True
if old_value is not None:
obj.element.remove(old_value.element)
if value is not None:
obj._insert_element(value.element)
self.values[obj] = value
if not same_value:
obj.__dirty__ = True
if self.onset:
self.onset(obj, self, value)
def __delete__(self, obj):
try:
old_value = self.values.pop(obj)
except KeyError:
pass
else:
if old_value is not None:
obj.element.remove(old_value.element)
obj.__dirty__ = True
if self.ondel:
self.ondel(obj, self)
class XMLStringChoiceChild(XMLElementChoiceChild):
"""
A choice between keyword strings from a registry, custom strings from the
other type and custom extensions. This descriptor will accept and return
strings instead of requiring XMLElement instances for the values in the
registry and the other type. Check XMLEmptyElementRegistryType for a
metaclass for building registries of XMLEmptyElement classes for keywords.
"""
def __init__(self, name, registry=None, other_type=None, extension_type=None):
self.registry = registry
self.other_type = other_type
self.extension_type = extension_type
types = registry.classes if registry is not None else ()
types += (other_type,) if other_type is not None else ()
super(XMLStringChoiceChild, self).__init__(name, types, extension_type=extension_type, required=True, test_equal=True)
def __get__(self, obj, objtype):
value = super(XMLStringChoiceChild, self).__get__(obj, objtype)
if obj is None or objtype is StoredAttribute or value is None or isinstance(value, self.extension_type or ()):
return value
else:
return str(value)
def __set__(self, obj, value):
if isinstance(value, str):
if self.registry is not None and value in self.registry.names:
value = self.registry.class_map[value]()
elif self.other_type is not None:
value = self.other_type.from_string(value)
super(XMLStringChoiceChild, self).__set__(obj, value)
## XMLElement base classes
class XMLElementBase(object):
"""
This class is used as a common ancestor for XML elements and provides
the means for super() to find at least dummy implementations for the
methods that are supposed to be implemented by subclasses, even when
they are not implemented by any other ancestor class. This is necessary
in order to simplify access to these methods when multiple inheritance
is involved and none or only some of the classes implement them.
The methods declared here should to be implemented in subclasses as
necessary.
"""
def __get_dirty__(self):
return False
def __set_dirty__(self, dirty):
return
def _build_element(self):
return
def _parse_element(self, element):
return
class XMLElementType(type):
def __init__(cls, name, bases, dct):
super(XMLElementType, cls).__init__(name, bases, dct)
# set dictionary of xml attributes and xml child elements
cls._xml_attributes = {}
cls._xml_element_children = {}
cls._xml_children_qname_map = {}
for base in reversed(bases):
if hasattr(base, '_xml_attributes'):
cls._xml_attributes.update(base._xml_attributes)
if hasattr(base, '_xml_element_children') and hasattr(base, '_xml_children_qname_map'):
cls._xml_element_children.update(base._xml_element_children)
cls._xml_children_qname_map.update(base._xml_children_qname_map)
for name, value in list(dct.items()):
if isinstance(value, XMLElementID):
if cls._xml_id is not None:
raise AttributeError("Only one XMLElementID attribute can be defined in the %s class" % cls.__name__)
cls._xml_id = value
cls._xml_attributes[value.name] = value
elif isinstance(value, XMLAttribute):
cls._xml_attributes[value.name] = value
elif isinstance(value, XMLElementChild):
cls._xml_element_children[value.name] = value
cls._xml_children_qname_map[value.type.qname] = (value, value.type)
elif isinstance(value, XMLElementChoiceChild):
cls._xml_element_children[value.name] = value
for type in value.types:
cls._xml_children_qname_map[type.qname] = (value, type)
# register class in its XMLDocument
if cls._xml_document is not None:
cls._xml_document.register_element(cls)
class XMLElement(XMLElementBase, metaclass=XMLElementType):
_xml_tag = None # To be defined in subclass
_xml_namespace = None # To be defined in subclass
_xml_document = None # To be defined in subclass
_xml_extension_type = None # Can be defined in subclass
_xml_id = None # Can be defined in subclass, or will be set by the metaclass to the XMLElementID attribute (if present)
_xml_children_order = {} # Can be defined in subclass
# dynamically generated
_xml_attributes = {}
_xml_element_children = {}
_xml_children_qname_map = {}
qname = classproperty(lambda cls: '{%s}%s' % (cls._xml_namespace, cls._xml_tag))
def __init__(self):
self.element = etree.Element(self.qname, nsmap=self._xml_document.nsmap)
self.__dirty__ = True
def __get_dirty__(self):
return (self.__dict__['__dirty__']
or any(child.__dirty__ for child in (getattr(self, name) for name in self._xml_element_children) if child is not None)
or super(XMLElement, self).__get_dirty__())
def __set_dirty__(self, dirty):
super(XMLElement, self).__set_dirty__(dirty)
if not dirty:
for child in (child for child in (getattr(self, name) for name in self._xml_element_children) if child is not None):
child.__dirty__ = dirty
self.__dict__['__dirty__'] = dirty
__dirty__ = property(__get_dirty__, __set_dirty__)
def check_validity(self):
# check attributes
for name, attribute in list(self._xml_attributes.items()):
# if attribute has default but it was not set, will also be added with this occasion
value = getattr(self, name, None)
if attribute.required and value is None:
raise ValidationError("required attribute %s of %s is not set" % (name, self.__class__.__name__))
# check element children
for name, element_child in list(self._xml_element_children.items()):
# if child has default but it was not set, will also be added with this occasion
child = getattr(self, name, None)
if child is None and element_child.required:
raise ValidationError("element child %s of %s is not set" % (name, self.__class__.__name__))
def to_element(self):
try:
self.check_validity()
except ValidationError as e:
raise BuilderError(str(e))
# build element children
for name in self._xml_element_children:
descriptor = getattr(self.__class__, name)
child = descriptor.__get__(self, StoredAttribute)
if child is not None:
child.to_element()
self._build_element()
return self.element
@classmethod
def from_element(cls, element, xml_document=None):
obj = cls.__new__(cls)
obj._xml_document = xml_document if xml_document is not None else cls._xml_document
obj.element = element
# set known attributes
for name, attribute in list(cls._xml_attributes.items()):
xmlvalue = element.get(attribute.xmlname, None)
if xmlvalue is not None:
try:
setattr(obj, name, attribute.parse(xmlvalue))
except (ValueError, TypeError):
raise ValidationError("got illegal value for attribute %s of %s: %s" % (name, cls.__name__, xmlvalue))
# set element children
for child in element:
element_child, type = cls._xml_children_qname_map.get(child.tag, (None, None))
if element_child is not None:
try:
value = type.from_element(child, xml_document=obj._xml_document)
except ValidationError:
pass # we should accept partially valid documents
else:
setattr(obj, element_child.name, value)
obj._parse_element(element)
obj.check_validity()
obj.__dirty__ = False
return obj
@classmethod
def _register_xml_attribute(cls, attribute, element):
cls._xml_element_children[attribute] = element
cls._xml_children_qname_map[element.type.qname] = (element, element.type)
for subclass in cls.__subclasses__():
subclass._register_xml_attribute(attribute, element)
@classmethod
def _unregister_xml_attribute(cls, attribute):
element = cls._xml_element_children.pop(attribute)
del cls._xml_children_qname_map[element.type.qname]
for subclass in cls.__subclasses__():
subclass._unregister_xml_attribute(attribute)
@classmethod
def register_extension(cls, attribute, type, test_equal=True):
if cls._xml_extension_type is None:
raise ValueError("XMLElement type %s does not support extensions (requested extension type %s)" % (cls.__name__, type.__name__))
elif not issubclass(type, cls._xml_extension_type):
raise TypeError("XMLElement type %s only supports extensions of type %s (requested extension type %s)" % (cls.__name__, cls._xml_extension_type, type.__name__))
elif hasattr(cls, attribute):
raise ValueError("XMLElement type %s already has an attribute named %s (requested extension type %s)" % (cls.__name__, attribute, type.__name__))
extension = XMLElementChild(attribute, type=type, required=False, test_equal=test_equal)
setattr(cls, attribute, extension)
cls._register_xml_attribute(attribute, extension)
@classmethod
def unregister_extension(cls, attribute):
if cls._xml_extension_type is None:
raise ValueError("XMLElement type %s does not support extensions" % cls.__name__)
cls._unregister_xml_attribute(attribute)
delattr(cls, attribute)
def _insert_element(self, element):
if element in self.element:
return
order = self._xml_children_order.get(element.tag, self._xml_children_order.get(None, sys.maxsize))
for i in range(len(self.element)):
child_order = self._xml_children_order.get(self.element[i].tag, self._xml_children_order.get(None, sys.maxsize))
if child_order > order:
position = i
break
else:
position = len(self.element)
self.element.insert(position, element)
def __eq__(self, other):
if isinstance(other, XMLElement):
if self is other:
return True
for name, attribute in list(self._xml_attributes.items()):
if attribute.test_equal:
if not hasattr(other, name) or getattr(self, name) != getattr(other, name):
return False
for name, element_child in list(self._xml_element_children.items()):
if element_child.test_equal:
if not hasattr(other, name) or getattr(self, name) != getattr(other, name):
return False
try:
__eq__ = super(XMLElement, self).__eq__
except AttributeError:
return True
else:
return __eq__(other)
elif self._xml_id is not None:
return self._xml_id == other
else:
return NotImplemented
def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal
def __hash__(self):
if self._xml_id is not None:
return hash(self._xml_id)
else:
return object.__hash__(self)
class XMLRootElementType(XMLElementType):
def __init__(cls, name, bases, dct):
super(XMLRootElementType, cls).__init__(name, bases, dct)
if cls._xml_document is not None:
if cls._xml_document.root_element is not None:
raise TypeError('there is already a root element registered for %s' % cls.__name__)
cls._xml_document.root_element = cls
class XMLRootElement(XMLElement, metaclass=XMLRootElementType):
def __init__(self):
XMLElement.__init__(self)
self.__cache__ = WeakValueDictionary({self.element: self})
@classmethod
def from_element(cls, element, xml_document=None):
obj = super(XMLRootElement, cls).from_element(element, xml_document)
obj.__cache__ = WeakValueDictionary({obj.element: obj})
return obj
@classmethod
def parse(cls, document):
return cls._xml_document.parse(document)
def toxml(self, encoding=None, pretty_print=False, validate=True):
return self._xml_document.build(self, encoding=encoding, pretty_print=pretty_print, validate=validate)
def xpath(self, xpath, namespaces=None):
result = []
try:
nodes = self.element.xpath(xpath, namespaces=namespaces)
except etree.XPathError:
raise ValueError("illegal XPath expression")
for element in (node for node in nodes if isinstance(node, etree._Element)):
if element in self.__cache__:
result.append(self.__cache__[element])
continue
if element is self.element:
self.__cache__[element] = self
result.append(self)
continue
for ancestor in element.iterancestors():
if ancestor in self.__cache__:
container = self.__cache__[ancestor]
break
else:
container = self
notvisited = deque([container])
visited = set()
while notvisited:
container = notvisited.popleft()
self.__cache__[container.element] = container
if isinstance(container, XMLListMixin):
children = set(child for child in container if isinstance(child, XMLElement) and child not in visited)
visited.update(children)
notvisited.extend(children)
for child in container._xml_element_children:
value = getattr(container, child)
if value is not None and value not in visited:
visited.add(value)
notvisited.append(value)
if element in self.__cache__:
result.append(self.__cache__[element])
return result
def get_xpath(self, element):
raise NotImplementedError
def find_parent(self, element):
raise NotImplementedError
## Mixin classes
class ThisClass(object):
"""
Special marker class that is used to indicate that an XMLListElement
subclass can be an item of itself. This is necessary because a class
cannot reference itself when defining _xml_item_type
"""
class XMLListMixinType(type):
def __init__(cls, name, bases, dct):
super(XMLListMixinType, cls).__init__(name, bases, dct)
if '_xml_item_type' in dct:
cls._xml_item_type = cls._xml_item_type # trigger __setattr__
def __setattr__(cls, name, value):
if name == '_xml_item_type':
if value is ThisClass:
value = cls
elif isinstance(value, tuple) and ThisClass in value:
value = tuple(cls if type is ThisClass else type for type in value)
if value is None:
cls._xml_item_element_types = ()
cls._xml_item_extension_types = ()
else:
item_types = value if isinstance(value, tuple) else (value,)
cls._xml_item_element_types = tuple(type for type in item_types if issubclass(type, XMLElement))
cls._xml_item_extension_types = tuple(type for type in item_types if not issubclass(type, XMLElement))
super(XMLListMixinType, cls).__setattr__(name, value)
class XMLListMixin(XMLElementBase, metaclass=XMLListMixinType):
"""A mixin representing a list of other XML elements"""
_xml_item_type = None
def __new__(cls, *args, **kw):
if cls._xml_item_type is None:
raise TypeError("The %s class cannot be instantiated because it doesn't define the _xml_item_type attribute" % cls.__name__)
instance = super(XMLListMixin, cls).__new__(cls)
instance._element_map = {}
instance._xmlid_map = defaultdict(dict)
return instance
def __contains__(self, item):
return item in iter(list(self._element_map.values()))
def __iter__(self):
return (self._element_map[element] for element in self.element if element in self._element_map)
def __len__(self):
return len(self._element_map)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, XMLListMixin):
return self is other or (len(self) == len(other) and all(self_item == other_item for self_item, other_item in zip(self, other)))
else:
return NotImplemented
def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal
def __getitem__(self, key):
if key is IterateTypes:
return (cls for cls, mapping in list(self._xmlid_map.items()) if mapping)
if not isinstance(key, tuple):
raise KeyError(key)
try:
cls, id = key
except ValueError:
raise KeyError(key)
if id is IterateIDs:
return iter(list(self._xmlid_map[cls].keys()))
elif id is IterateItems:
return iter(list(self._xmlid_map[cls].values()))
else:
return self._xmlid_map[cls][id]
def __delitem__(self, key):
if not isinstance(key, tuple):
raise KeyError(key)
try:
cls, id = key
except ValueError:
raise KeyError(key)
if id is All:
for item in list(self._xmlid_map[cls].values()):
self.remove(item)
else:
self.remove(self._xmlid_map[cls][id])
def __get_dirty__(self):
return any(item.__dirty__ for item in list(self._element_map.values())) or super(XMLListMixin, self).__get_dirty__()
def __set_dirty__(self, dirty):
super(XMLListMixin, self).__set_dirty__(dirty)
if not dirty:
for item in list(self._element_map.values()):
item.__dirty__ = dirty
def _parse_element(self, element):
super(XMLListMixin, self)._parse_element(element)
self._element_map.clear()
self._xmlid_map.clear()
for child in element[:]:
child_class = self._xml_document.get_element(child.tag, type(None))
if child_class in self._xml_item_element_types or issubclass(child_class, self._xml_item_extension_types):
try:
value = child_class.from_element(child, xml_document=self._xml_document)
except ValidationError:
pass
else:
if value._xml_id is not None and value._xml_id in self._xmlid_map[child_class]:
element.remove(child)
else:
if value._xml_id is not None:
self._xmlid_map[child_class][value._xml_id] = value
self._element_map[value.element] = value
def _build_element(self):
super(XMLListMixin, self)._build_element()
for child in list(self._element_map.values()):
child.to_element()
def add(self, item):
if not (item.__class__ in self._xml_item_element_types or isinstance(item, self._xml_item_extension_types)):
raise TypeError("%s cannot add items of type %s" % (self.__class__.__name__, item.__class__.__name__))
same_value = False
if item._xml_id is not None and item._xml_id in self._xmlid_map[item.__class__]:
old_item = self._xmlid_map[item.__class__][item._xml_id]
if item is old_item:
return
elif item == old_item:
item.__dirty__ = old_item.__dirty__
same_value = True
self.element.remove(old_item.element)
del self._xmlid_map[item.__class__][item._xml_id]
del self._element_map[old_item.element]
self._insert_element(item.element)
if item._xml_id is not None:
self._xmlid_map[item.__class__][item._xml_id] = item
self._element_map[item.element] = item
if not same_value:
self.__dirty__ = True
def remove(self, item):
self.element.remove(item.element)
if item._xml_id is not None:
del self._xmlid_map[item.__class__][item._xml_id]
del self._element_map[item.element]
self.__dirty__ = True
def update(self, sequence):
for item in sequence:
self.add(item)
def clear(self):
for item in list(self._element_map.values()):
self.remove(item)
## Element types
class XMLSimpleElement(XMLElement):
_xml_value_type = None # To be defined in subclass
def __new__(cls, *args, **kw):
if cls._xml_value_type is None:
raise TypeError("The %s class cannot be instantiated because it doesn't define the _xml_value_type attribute" % cls.__name__)
return super(XMLSimpleElement, cls).__new__(cls)
def __init__(self, value):
XMLElement.__init__(self)
self.value = value
def __eq__(self, other):
if isinstance(other, XMLSimpleElement):
return self is other or self.value == other.value
else:
return self.value == other
def __bool__(self):
return bool(self.value)
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self.value)
def __str__(self):
return str(self.value)
def __unicode__(self):
return str(self.value)
@property
def value(self):
return self.__dict__['value']
@value.setter
def value(self, value):
if not isinstance(value, self._xml_value_type):
value = self._xml_value_type(value)
if self.__dict__.get('value', Null) == value:
return
self.__dict__['value'] = value
self.__dirty__ = True
def _parse_element(self, element):
super(XMLSimpleElement, self)._parse_element(element)
value = element.text or ''
if hasattr(self._xml_value_type, '__xmlparse__'):
self.value = self._xml_value_type.__xmlparse__(value)
else:
self.value = self._xml_value_type(value)
def _build_element(self):
super(XMLSimpleElement, self)._build_element()
if hasattr(self.value, '__xmlbuild__'):
self.element.text = self.value.__xmlbuild__()
else:
self.element.text = str(self.value)
class XMLStringElement(XMLSimpleElement):
_xml_value_type = str # Can be overwritten in subclasses
def __len__(self):
return len(self.value)
class XMLLocalizedStringElement(XMLStringElement):
lang = XMLAttribute('lang', xmlname='{http://www.w3.org/XML/1998/namespace}lang', type=str, required=False, test_equal=True)
def __init__(self, value, lang=None):
XMLStringElement.__init__(self, value)
self.lang = lang
def __eq__(self, other):
if isinstance(other, XMLLocalizedStringElement):
return self is other or (self.lang == other.lang and self.value == other.value)
elif self.lang is None:
return XMLStringElement.__eq__(self, other)
else:
return NotImplemented
def __repr__(self):
return '%s(%r, %r)' % (self.__class__.__name__, self.value, self.lang)
def _parse_element(self, element):
super(XMLLocalizedStringElement, self)._parse_element(element)
self.lang = element.get('{http://www.w3.org/XML/1998/namespace}lang', None)
class XMLBooleanElement(XMLSimpleElement):
_xml_value_type = Boolean
class XMLByteElement(XMLSimpleElement):
_xml_value_type = Byte
class XMLUnsignedByteElement(XMLSimpleElement):
_xml_value_type = UnsignedByte
class XMLShortElement(XMLSimpleElement):
_xml_value_type = Short
class XMLUnsignedShortElement(XMLSimpleElement):
_xml_value_type = UnsignedShort
class XMLIntElement(XMLSimpleElement):
_xml_value_type = Int
class XMLUnsignedIntElement(XMLSimpleElement):
_xml_value_type = UnsignedInt
class XMLLongElement(XMLSimpleElement):
_xml_value_type = Long
class XMLUnsignedLongElement(XMLSimpleElement):
_xml_value_type = UnsignedLong
class XMLIntegerElement(XMLSimpleElement):
_xml_value_type = int
class XMLPositiveIntegerElement(XMLSimpleElement):
_xml_value_type = PositiveInteger
class XMLNegativeIntegerElement(XMLSimpleElement):
_xml_value_type = NegativeInteger
class XMLNonNegativeIntegerElement(XMLSimpleElement):
_xml_value_type = NonNegativeInteger
class XMLNonPositiveIntegerElement(XMLSimpleElement):
_xml_value_type = NonPositiveInteger
class XMLDecimalElement(XMLSimpleElement):
_xml_value_type = Decimal
class XMLDateTimeElement(XMLSimpleElement):
_xml_value_type = DateTime
class XMLAnyURIElement(XMLStringElement):
_xml_value_type = AnyURI
class XMLEmptyElement(XMLElement):
def __repr__(self):
return '%s()' % self.__class__.__name__
def __eq__(self, other):
return type(self) is type(other) or NotImplemented
def __hash__(self):
return hash(self.__class__)
class XMLEmptyElementRegistryType(type):
"""A metaclass for building registries of XMLEmptyElement subclasses from names"""
def __init__(cls, name, bases, dct):
super(XMLEmptyElementRegistryType, cls).__init__(name, bases, dct)
typename = getattr(cls, '__typename__', name.partition('Registry')[0]).capitalize()
class BaseElementType(XMLEmptyElement):
def __str__(self): return self._xml_tag
def __unicode__(self): return str(self._xml_tag)
cls.__basetype__ = BaseElementType
cls.__basetype__.__name__ = 'Base%sType' % typename
cls.class_map = {}
for name in cls.names:
class ElementType(BaseElementType):
_xml_tag = name
_xml_namespace = cls._xml_namespace
_xml_document = cls._xml_document
_xml_id = name
translation_table = dict.fromkeys(map(ord, '-_'), None)
ElementType.__name__ = typename + name.title().translate(translation_table)
cls.class_map[name] = ElementType
cls.classes = tuple(cls.class_map[name] for name in cls.names)
## Created using mixins
class XMLListElementType(XMLElementType, XMLListMixinType): pass
class XMLListRootElementType(XMLRootElementType, XMLListMixinType): pass
class XMLListElement(XMLElement, XMLListMixin, metaclass=XMLListElementType):
def __bool__(self):
if self._xml_attributes or self._xml_element_children:
return True
else:
return len(self._element_map) != 0
class XMLListRootElement(XMLRootElement, XMLListMixin, metaclass=XMLListRootElementType):
def __bool__(self):
if self._xml_attributes or self._xml_element_children:
return True
else:
return len(self._element_map) != 0
class XMLStringListElementType(XMLListElementType):
def __init__(cls, name, bases, dct):
if cls._xml_item_type is not None:
raise TypeError("The %s class should not define _xml_item_type, but define _xml_item_registry, _xml_item_other_type and _xml_item_extension_type instead" % cls.__name__)
types = cls._xml_item_registry.classes if cls._xml_item_registry is not None else ()
types += tuple(type for type in (cls._xml_item_other_type, cls._xml_item_extension_type) if type is not None)
cls._xml_item_type = types or None
super(XMLStringListElementType, cls).__init__(name, bases, dct)
class XMLStringListElement(XMLListElement, metaclass=XMLStringListElementType):
_xml_item_registry = None
_xml_item_other_type = None
_xml_item_extension_type = None
def __contains__(self, item):
if isinstance(item, str):
if self._xml_item_registry is not None and item in self._xml_item_registry.names:
item = self._xml_item_registry.class_map[item]()
elif self._xml_item_other_type is not None:
item = self._xml_item_other_type.from_string(item)
return item in iter(list(self._element_map.values()))
def __iter__(self):
return (item if isinstance(item, self._xml_item_extension_types) else str(item) for item in super(XMLStringListElement, self).__iter__())
def add(self, item):
if isinstance(item, str):
if self._xml_item_registry is not None and item in self._xml_item_registry.names:
item = self._xml_item_registry.class_map[item]()
elif self._xml_item_other_type is not None:
item = self._xml_item_other_type.from_string(item)
super(XMLStringListElement, self).add(item)
def remove(self, item):
if isinstance(item, str):
if self._xml_item_registry is not None and item in self._xml_item_registry.names:
xmlitem = self._xml_item_registry.class_map[item]()
try:
item = next((entry for entry in super(XMLStringListElement, self).__iter__() if entry == xmlitem))
except StopIteration:
raise KeyError(item)
elif self._xml_item_other_type is not None:
xmlitem = self._xml_item_other_type.from_string(item)
try:
item = next((entry for entry in super(XMLStringListElement, self).__iter__() if entry == xmlitem))
except StopIteration:
raise KeyError(item)
super(XMLStringListElement, self).remove(item)
diff --git a/sipsimple/streams/msrp/chat.py b/sipsimple/streams/msrp/chat.py
index 08650500..d5a4c644 100644
--- a/sipsimple/streams/msrp/chat.py
+++ b/sipsimple/streams/msrp/chat.py
@@ -1,907 +1,905 @@
"""
This module provides classes to parse and generate SDP related to SIP sessions that negotiate Instant Messaging, including CPIM as defined in RFC3862
"""
import pickle as pickle
import codecs
import os
import random
import re
from application.python.descriptor import WriteOnceAttribute
from application.notification import IObserver, NotificationCenter, NotificationData
from application.python import Null
from application.python.types import Singleton
from application.system import openfile
from collections import defaultdict
from email.message import Message as EmailMessage
from email.parser import Parser as EmailParser
from eventlib.coros import queue
from eventlib.proc import spawn, ProcExit
from functools import partial
from msrplib.protocol import FailureReportHeader, SuccessReportHeader, UseNicknameHeader
from msrplib.session import MSRPSession, contains_mime_type
from otr import OTRSession, OTRTransport, OTRState, SMPStatus
from otr.cryptography import DSAPrivateKey
from otr.exceptions import IgnoreMessage, UnencryptedMessage, EncryptedMessageError, OTRError
from zope.interface import implementer
from sipsimple.core import SIPURI, BaseSIPURI
from sipsimple.payloads import ParserError
from sipsimple.payloads.iscomposing import IsComposingDocument, State, LastActive, Refresh, ContentType
from sipsimple.storage import ISIPSimpleApplicationDataStorage
from sipsimple.streams import InvalidStreamError, UnknownStreamError
from sipsimple.streams.msrp import MSRPStreamError, MSRPStreamBase
from sipsimple.threading import run_in_thread, run_in_twisted_thread
from sipsimple.threading.green import run_in_green_thread
from sipsimple.util import MultilingualText, ISOTimestamp
__all__ = ['ChatStream', 'ChatStreamError', 'ChatIdentity', 'CPIMPayload', 'CPIMHeader', 'CPIMNamespace', 'CPIMParserError', 'OTRState', 'SMPStatus']
class OTRTrustedPeer(object):
fingerprint = WriteOnceAttribute() # in order to be hashable this needs to be immutable
def __init__(self, fingerprint, description='', **kw):
if not isinstance(fingerprint, str):
raise TypeError("fingerprint must be a string")
self.fingerprint = fingerprint
self.description = description
self.__dict__.update(kw)
def __hash__(self):
return hash(self.fingerprint)
def __eq__(self, other):
if isinstance(other, OTRTrustedPeer):
return self.fingerprint == other.fingerprint
elif isinstance(other, str):
return self.fingerprint == other
else:
return NotImplemented
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return "{0.__class__.__name__}({0.fingerprint!r}, description={0.description!r})".format(self)
def __reduce__(self):
return self.__class__, (self.fingerprint,), self.__dict__
class OTRTrustedPeerSet(object):
def __init__(self, iterable=()):
self.__data__ = {}
self.update(iterable)
def __repr__(self):
return "{}({})".format(self.__class__.__name__, list(self.__data__.values()))
def __contains__(self, item):
return item in self.__data__
def __getitem__(self, item):
return self.__data__[item]
def __iter__(self):
return iter(list(self.__data__.values()))
def __len__(self):
return len(self.__data__)
def get(self, item, default=None):
return self.__data__.get(item, default)
def add(self, item):
if not isinstance(item, OTRTrustedPeer):
raise TypeError("item should be and instance of OTRTrustedPeer")
self.__data__[item.fingerprint] = item
def remove(self, item):
del self.__data__[item]
def discard(self, item):
self.__data__.pop(item, None)
def update(self, iterable=()):
for item in iterable:
self.add(item)
class OTRCache(object, metaclass=Singleton):
def __init__(self):
from sipsimple.application import SIPApplication
if SIPApplication.storage is None:
raise RuntimeError("Cannot access the OTR cache before SIPApplication.storage is defined")
if ISIPSimpleApplicationDataStorage.providedBy(SIPApplication.storage):
self.key_file = os.path.join(SIPApplication.storage.directory, 'otr.key')
self.trusted_file = os.path.join(SIPApplication.storage.directory, 'otr.trusted')
try:
self.private_key = DSAPrivateKey.load(self.key_file)
if self.private_key.key_size != 1024:
raise ValueError
except (EnvironmentError, ValueError):
self.private_key = DSAPrivateKey.generate()
self.private_key.save(self.key_file)
try:
self.trusted_peers = pickle.load(open(self.trusted_file, 'rb'))
if not isinstance(self.trusted_peers, OTRTrustedPeerSet) or not all(isinstance(item, OTRTrustedPeer) for item in self.trusted_peers):
raise ValueError("invalid OTR trusted peers file")
except Exception:
self.trusted_peers = OTRTrustedPeerSet()
self.save()
else:
self.key_file = self.trusted_file = None
self.private_key = DSAPrivateKey.generate()
self.trusted_peers = OTRTrustedPeerSet()
# def generate_private_key(self):
# self.private_key = DSAPrivateKey.generate()
# if self.key_file:
# self.private_key.save(self.key_file)
@run_in_thread('file-io')
def save(self):
if self.trusted_file is not None:
with openfile(self.trusted_file, 'wb', permissions=0o600) as trusted_file:
pickle.dump(self.trusted_peers, trusted_file)
@implementer(IObserver)
class OTREncryption(object):
def __init__(self, stream):
self.stream = stream
self.otr_cache = OTRCache()
self.otr_session = OTRSession(self.otr_cache.private_key, self.stream, supported_versions={3}) # we need at least OTR-v3 for question based SMP
notification_center = NotificationCenter()
notification_center.add_observer(self, sender=stream)
notification_center.add_observer(self, sender=self.otr_session)
@property
def active(self):
try:
return self.otr_session.encrypted
except AttributeError:
return False
@property
def cipher(self):
return 'AES-128-CTR' if self.active else None
@property
def key_fingerprint(self):
try:
return self.otr_session.local_private_key.public_key.fingerprint
except AttributeError:
return None
@property
def peer_fingerprint(self):
try:
return self.otr_session.remote_public_key.fingerprint
except AttributeError:
return None
@property
def peer_name(self):
try:
return self.__dict__['peer_name']
except KeyError:
trusted_peer = self.otr_cache.trusted_peers.get(self.peer_fingerprint, None)
if trusted_peer is None:
return ''
else:
return self.__dict__.setdefault('peer_name', trusted_peer.description)
@peer_name.setter
def peer_name(self, name):
old_name = self.peer_name
new_name = self.__dict__['peer_name'] = name
if old_name != new_name:
trusted_peer = self.otr_cache.trusted_peers.get(self.peer_fingerprint, None)
if trusted_peer is not None:
trusted_peer.description = new_name
self.otr_cache.save()
notification_center = NotificationCenter()
notification_center.post_notification("ChatStreamOTRPeerNameChanged", sender=self.stream, data=NotificationData(name=name))
@property
def verified(self):
return self.peer_fingerprint in self.otr_cache.trusted_peers
@verified.setter
def verified(self, value):
peer_fingerprint = self.peer_fingerprint
old_verified = peer_fingerprint in self.otr_cache.trusted_peers
new_verified = bool(value)
if peer_fingerprint is None or new_verified == old_verified:
return
if new_verified:
self.otr_cache.trusted_peers.add(OTRTrustedPeer(peer_fingerprint, description=self.peer_name))
else:
self.otr_cache.trusted_peers.remove(peer_fingerprint)
self.otr_cache.save()
notification_center = NotificationCenter()
notification_center.post_notification("ChatStreamOTRVerifiedStateChanged", sender=self.stream, data=NotificationData(verified=new_verified))
@run_in_twisted_thread
def start(self):
if self.otr_session is not None:
self.otr_session.start()
@run_in_twisted_thread
def stop(self):
if self.otr_session is not None:
self.otr_session.stop()
@run_in_twisted_thread
def smp_verify(self, secret, question=None):
self.otr_session.smp_verify(secret, question)
@run_in_twisted_thread
def smp_answer(self, secret):
self.otr_session.smp_answer(secret)
@run_in_twisted_thread
def smp_abort(self):
self.otr_session.smp_abort()
def handle_notification(self, notification):
handler = getattr(self, '_NH_%s' % notification.name, Null)
handler(notification)
def _NH_MediaStreamDidStart(self, notification):
if self.stream.start_otr:
self.otr_session.start()
def _NH_MediaStreamDidEnd(self, notification):
notification.center.remove_observer(self, sender=self.stream)
notification.center.remove_observer(self, sender=self.otr_session)
self.otr_session.stop()
self.otr_session = None
self.stream = None
_NH_MediaStreamDidNotInitialize = _NH_MediaStreamDidEnd
def _NH_OTRSessionStateChanged(self, notification):
notification.center.post_notification('ChatStreamOTREncryptionStateChanged', sender=self.stream, data=notification.data)
def _NH_OTRSessionSMPVerificationDidStart(self, notification):
notification.center.post_notification('ChatStreamSMPVerificationDidStart', sender=self.stream, data=notification.data)
def _NH_OTRSessionSMPVerificationDidNotStart(self, notification):
notification.center.post_notification('ChatStreamSMPVerificationDidNotStart', sender=self.stream, data=notification.data)
def _NH_OTRSessionSMPVerificationDidEnd(self, notification):
notification.center.post_notification('ChatStreamSMPVerificationDidEnd', sender=self.stream, data=notification.data)
class ChatStreamError(MSRPStreamError): pass
class ChatStream(MSRPStreamBase):
type = 'chat'
priority = 1
msrp_session_class = MSRPSession
media_type = 'message'
accept_types = ['message/cpim', 'text/*', 'image/*', 'application/im-iscomposing+xml']
accept_wrapped_types = ['text/*', 'image/*', 'application/im-iscomposing+xml']
prefer_cpim = True
start_otr = True
def __init__(self):
super(ChatStream, self).__init__(direction='sendrecv')
self.message_queue = queue()
self.sent_messages = set()
self.incoming_queue = defaultdict(list)
self.message_queue_thread = None
self.encryption = OTREncryption(self)
@classmethod
def new_from_sdp(cls, session, remote_sdp, stream_index):
remote_stream = remote_sdp.media[stream_index]
if remote_stream.media != 'message':
raise UnknownStreamError
expected_transport = 'TCP/TLS/MSRP' if session.account.msrp.transport=='tls' else 'TCP/MSRP'
if remote_stream.transport != expected_transport:
raise InvalidStreamError("expected %s transport in chat stream, got %s" % (expected_transport, remote_stream.transport))
if remote_stream.formats != ['*']:
raise InvalidStreamError("wrong format list specified")
stream = cls()
stream.remote_role = remote_stream.attributes.getfirst('setup', 'active')
if remote_stream.direction != 'sendrecv':
raise InvalidStreamError("Unsupported direction for chat stream: %s" % remote_stream.direction)
remote_accept_types = remote_stream.attributes.getfirst('accept-types')
if remote_accept_types is None:
raise InvalidStreamError("remote SDP media does not have 'accept-types' attribute")
if not any(contains_mime_type(cls.accept_types, mime_type) for mime_type in remote_accept_types.split()):
raise InvalidStreamError("no compatible media types found")
return stream
@property
def local_identity(self):
try:
return ChatIdentity(self.session.local_identity.uri, self.session.local_identity.display_name)
except AttributeError:
return None
@property
def remote_identity(self):
try:
return ChatIdentity(self.session.remote_identity.uri, self.session.remote_identity.display_name)
except AttributeError:
return None
@property
def private_messages_allowed(self):
return 'private-messages' in self.chatroom_capabilities
@property
def nickname_allowed(self):
return 'nickname' in self.chatroom_capabilities
@property
def chatroom_capabilities(self):
try:
if self.cpim_enabled and self.session.remote_focus:
return ' '.join(self.remote_media.attributes.getall('chatroom')).split()
except AttributeError:
pass
return []
def _NH_MediaStreamDidStart(self, notification):
self.message_queue_thread = spawn(self._message_queue_handler)
def _NH_MediaStreamDidNotInitialize(self, notification):
message_queue, self.message_queue = self.message_queue, queue()
while message_queue:
message = message_queue.wait()
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason='Stream was closed')
notification.center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
def _NH_MediaStreamDidEnd(self, notification):
if self.message_queue_thread is not None:
self.message_queue_thread.kill()
else:
message_queue, self.message_queue = self.message_queue, queue()
while message_queue:
message = message_queue.wait()
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason='Stream ended')
notification.center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
def _handle_REPORT(self, chunk):
# in theory, REPORT can come with Byte-Range which would limit the scope of the REPORT to the part of the message.
if chunk.message_id in self.sent_messages:
self.sent_messages.remove(chunk.message_id)
notification_center = NotificationCenter()
data = NotificationData(message_id=chunk.message_id, message=chunk, code=chunk.status.code, reason=chunk.status.comment)
if chunk.status.code == 200:
notification_center.post_notification('ChatStreamDidDeliverMessage', sender=self, data=data)
else:
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
def _handle_SEND(self, chunk):
if chunk.size == 0: # keep-alive
self.msrp_session.send_report(chunk, 200, 'OK')
return
content_type = chunk.content_type.lower()
if not contains_mime_type(self.accept_types, content_type):
self.msrp_session.send_report(chunk, 413, 'Unwanted Message')
return
if chunk.contflag == '#':
self.incoming_queue.pop(chunk.message_id, None)
self.msrp_session.send_report(chunk, 200, 'OK')
return
elif chunk.contflag == '+':
self.incoming_queue[chunk.message_id].append(chunk.data)
self.msrp_session.send_report(chunk, 200, 'OK')
return
else:
data = ''.join(self.incoming_queue.pop(chunk.message_id, [])) + chunk.data
-
+
if content_type == 'message/cpim':
try:
payload = CPIMPayload.decode(data)
except CPIMParserError:
self.msrp_session.send_report(chunk, 400, 'CPIM Parser Error')
return
else:
message = Message(**{name: getattr(payload, name) for name in Message.__slots__})
if not contains_mime_type(self.accept_wrapped_types, message.content_type):
self.msrp_session.send_report(chunk, 413, 'Unwanted Message')
return
if message.timestamp is None:
message.timestamp = ISOTimestamp.now()
if message.sender is None:
message.sender = self.remote_identity
private = self.session.remote_focus and len(message.recipients) == 1 and message.recipients[0] != self.remote_identity
else:
payload = SimplePayload.decode(data, content_type)
message = Message(payload.content, payload.content_type, sender=self.remote_identity, recipients=[self.local_identity], timestamp=ISOTimestamp.now())
private = False
try:
- message.content = self.encryption.otr_session.handle_input(message.content, message.content_type)
+ message.content = self.encryption.otr_session.handle_input(message.content.encode(), message.content_type)
except IgnoreMessage:
self.msrp_session.send_report(chunk, 200, 'OK')
return
except UnencryptedMessage:
encrypted = False
encryption_active = True
except EncryptedMessageError as e:
self.msrp_session.send_report(chunk, 400, str(e))
notification_center = NotificationCenter()
notification_center.post_notification('ChatStreamOTRError', sender=self, data=NotificationData(error=str(e)))
return
except OTRError as e:
self.msrp_session.send_report(chunk, 200, 'OK')
notification_center = NotificationCenter()
notification_center.post_notification('ChatStreamOTRError', sender=self, data=NotificationData(error=str(e)))
return
else:
encrypted = encryption_active = self.encryption.active
if payload.charset is not None:
message.content = message.content.decode(payload.charset)
elif payload.content_type.startswith('text/'):
message.content.decode('utf8')
notification_center = NotificationCenter()
if message.content_type.lower() == IsComposingDocument.content_type:
try:
document = IsComposingDocument.parse(message.content)
except ParserError as e:
self.msrp_session.send_report(chunk, 400, str(e))
return
self.msrp_session.send_report(chunk, 200, 'OK')
data = NotificationData(state=document.state.value,
refresh=document.refresh.value if document.refresh is not None else 120,
content_type=document.content_type.value if document.content_type is not None else None,
last_active=document.last_active.value if document.last_active is not None else None,
sender=message.sender, recipients=message.recipients, private=private,
encrypted=encrypted, encryption_active=encryption_active)
notification_center.post_notification('ChatStreamGotComposingIndication', sender=self, data=data)
else:
self.msrp_session.send_report(chunk, 200, 'OK')
data = NotificationData(message=message, private=private, encrypted=encrypted, encryption_active=encryption_active)
notification_center.post_notification('ChatStreamGotMessage', sender=self, data=data)
def _on_transaction_response(self, message_id, response):
if message_id in self.sent_messages and response.code != 200:
self.sent_messages.remove(message_id)
data = NotificationData(message_id=message_id, message=response, code=response.code, reason=response.comment)
NotificationCenter().post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
def _on_nickname_transaction_response(self, message_id, response):
notification_center = NotificationCenter()
if response.code == 200:
notification_center.post_notification('ChatStreamDidSetNickname', sender=self, data=NotificationData(message_id=message_id, response=response))
else:
notification_center.post_notification('ChatStreamDidNotSetNickname', sender=self, data=NotificationData(message_id=message_id, message=response, code=response.code, reason=response.comment))
def _message_queue_handler(self):
notification_center = NotificationCenter()
try:
while True:
message = self.message_queue.wait()
if self.msrp_session is None:
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason='Stream ended')
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
break
try:
if isinstance(message.content, str):
message.content = message.content.encode('utf8')
charset = 'utf8'
else:
charset = None
if not isinstance(message, QueuedOTRInternalMessage):
try:
message.content = self.encryption.otr_session.handle_output(message.content, message.content_type)
except OTRError as e:
raise ChatStreamError(str(e))
message.sender = message.sender or self.local_identity
message.recipients = message.recipients or [self.remote_identity]
# check if we MUST use CPIM
need_cpim = (message.sender != self.local_identity or message.recipients != [self.remote_identity] or
message.courtesy_recipients or message.subject or message.timestamp or message.required or message.additional_headers)
if need_cpim or not contains_mime_type(self.remote_accept_types, message.content_type):
if not contains_mime_type(self.remote_accept_wrapped_types, message.content_type):
raise ChatStreamError('Unsupported content_type for outgoing message: %r' % message.content_type)
if not self.cpim_enabled:
raise ChatStreamError('Additional message meta-data cannot be sent, because the CPIM wrapper is not used')
if not self.private_messages_allowed and message.recipients != [self.remote_identity]:
raise ChatStreamError('The remote end does not support private messages')
if message.timestamp is None:
message.timestamp = ISOTimestamp.now()
payload = CPIMPayload(charset=charset, **{name: getattr(message, name) for name in Message.__slots__})
elif self.prefer_cpim and self.cpim_enabled and contains_mime_type(self.remote_accept_wrapped_types, message.content_type):
if message.timestamp is None:
message.timestamp = ISOTimestamp.now()
payload = CPIMPayload(charset=charset, **{name: getattr(message, name) for name in Message.__slots__})
else:
payload = SimplePayload(message.content, message.content_type, charset)
except ChatStreamError as e:
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason=e.args[0])
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
continue
else:
content, content_type = payload.encode()
message_id = message.id
notify_progress = message.notify_progress
report = 'yes' if notify_progress else 'no'
chunk = self.msrp_session.make_message(content, content_type=content_type, message_id=message_id)
chunk.add_header(FailureReportHeader(report))
chunk.add_header(SuccessReportHeader(report))
try:
self.msrp_session.send_chunk(chunk, response_cb=partial(self._on_transaction_response, message_id))
except Exception as e:
if notify_progress:
data = NotificationData(message_id=message_id, message=None, code=0, reason=str(e))
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
except ProcExit:
if notify_progress:
data = NotificationData(message_id=message_id, message=None, code=0, reason='Stream ended')
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
raise
else:
if notify_progress:
self.sent_messages.add(message_id)
notification_center.post_notification('ChatStreamDidSendMessage', sender=self, data=NotificationData(message=chunk))
finally:
self.message_queue_thread = None
while self.sent_messages:
message_id = self.sent_messages.pop()
data = NotificationData(message_id=message_id, message=None, code=0, reason='Stream ended')
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
message_queue, self.message_queue = self.message_queue, queue()
while message_queue:
message = message_queue.wait()
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason='Stream ended')
notification_center.post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
@run_in_twisted_thread
def _enqueue_message(self, message):
if self._done:
if message.notify_progress:
data = NotificationData(message_id=message.id, message=None, code=0, reason='Stream ended')
NotificationCenter().post_notification('ChatStreamDidNotDeliverMessage', sender=self, data=data)
else:
self.message_queue.send(message)
@run_in_green_thread
def _set_local_nickname(self, nickname, message_id):
if self.msrp_session is None:
# should we generate ChatStreamDidNotSetNickname here?
return
chunk = self.msrp.make_request('NICKNAME')
chunk.add_header(UseNicknameHeader(nickname or ''))
try:
self.msrp_session.send_chunk(chunk, response_cb=partial(self._on_nickname_transaction_response, message_id))
except Exception as e:
self._failure_reason = str(e)
NotificationCenter().post_notification('MediaStreamDidFail', sender=self, data=NotificationData(context='sending', reason=self._failure_reason))
def inject_otr_message(self, data):
message = QueuedOTRInternalMessage(data)
self._enqueue_message(message)
def send_message(self, content, content_type='text/plain', recipients=None, courtesy_recipients=None, subject=None, timestamp=None, required=None, additional_headers=None):
message = QueuedMessage(content, content_type, recipients=recipients, courtesy_recipients=courtesy_recipients, subject=subject, timestamp=timestamp, required=required, additional_headers=additional_headers, notify_progress=True)
self._enqueue_message(message)
return message.id
def send_composing_indication(self, state, refresh=None, last_active=None, recipients=None):
content = IsComposingDocument.create(state=State(state), refresh=Refresh(refresh) if refresh is not None else None, last_active=LastActive(last_active) if last_active is not None else None, content_type=ContentType('text'))
message = QueuedMessage(content, IsComposingDocument.content_type, recipients=recipients, notify_progress=False)
self._enqueue_message(message)
return message.id
def set_local_nickname(self, nickname):
if not self.nickname_allowed:
raise ChatStreamError('Setting nickname is not supported')
message_id = '%x' % random.getrandbits(64)
self._set_local_nickname(nickname, message_id)
return message_id
OTRTransport.register(ChatStream)
# Chat related objects, including CPIM support as defined in RFC3862
#
class ChatIdentity(object):
_format_re = re.compile(r'^(?:"?(?P[^<]*[^"\s])"?)?\s*<(?Psips?:.+)>$')
def __init__(self, uri, display_name=None):
+ print('ChatIdentity %s %s %s %s' % (uri, type(uri), display_name, type(display_name)))
+
self.uri = uri
self.display_name = display_name
def __eq__(self, other):
if isinstance(other, ChatIdentity):
return self.uri.user == other.uri.user and self.uri.host == other.uri.host
elif isinstance(other, BaseSIPURI):
return self.uri.user == other.user and self.uri.host == other.host
elif isinstance(other, str):
try:
other_uri = SIPURI.parse(other)
except Exception:
return False
else:
return self.uri.user == other_uri.user and self.uri.host == other_uri.host
else:
return NotImplemented
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return '{0.__class__.__name__}(uri={0.uri!r}, display_name={0.display_name!r})'.format(self)
def __str__(self):
- return self.__unicode__().encode('utf-8')
-
- def __unicode__(self):
if self.display_name:
return '{0.display_name} <{0.uri}>'.format(self)
else:
return '<{0.uri}>'.format(self)
@classmethod
def parse(cls, value):
match = cls._format_re.match(value)
if match is None:
raise ValueError('Cannot parse identity value: %r' % value)
return cls(SIPURI.parse(match.group('uri')), match.group('display_name'))
class Message(object):
__slots__ = 'content', 'content_type', 'sender', 'recipients', 'courtesy_recipients', 'subject', 'timestamp', 'required', 'additional_headers'
def __init__(self, content, content_type, sender=None, recipients=None, courtesy_recipients=None, subject=None, timestamp=None, required=None, additional_headers=None):
self.content = content
self.content_type = content_type
self.sender = sender
self.recipients = recipients or []
self.courtesy_recipients = courtesy_recipients or []
self.subject = subject
self.timestamp = ISOTimestamp(timestamp) if timestamp is not None else None
self.required = required or []
self.additional_headers = additional_headers or []
class QueuedMessage(Message):
__slots__ = 'id', 'notify_progress'
def __init__(self, content, content_type, sender=None, recipients=None, courtesy_recipients=None, subject=None, timestamp=None, required=None, additional_headers=None, id=None, notify_progress=True):
super(QueuedMessage, self).__init__(content, content_type, sender, recipients, courtesy_recipients, subject, timestamp, required, additional_headers)
self.id = id or '%x' % random.getrandbits(64)
self.notify_progress = notify_progress
class QueuedOTRInternalMessage(QueuedMessage):
def __init__(self, content):
super(QueuedOTRInternalMessage, self).__init__(content, 'text/plain', notify_progress=False)
class SimplePayload(object):
def __init__(self, content, content_type, charset=None):
if not isinstance(content, bytes):
raise TypeError("content should be an instance of bytes")
self.content = content
self.content_type = content_type
self.charset = charset
def encode(self):
if self.charset is not None:
return self.content, '{0.content_type}; charset="{0.charset}"'.format(self)
else:
return self.content, str(self.content_type)
@classmethod
def decode(cls, content, content_type):
if not isinstance(content, bytes):
raise TypeError("content should be an instance of bytes")
type_helper = EmailParser().parsestr('Content-Type: {}'.format(content_type))
content_type = type_helper.get_content_type()
charset = type_helper.get_content_charset()
return cls(content, content_type, charset)
class CPIMPayload(object):
standard_namespace = 'urn:ietf:params:cpim-headers:'
headers_re = re.compile(r'(?:([^:]+?)\.)?(.+?):\s*(.+?)(?:\r\n|$)')
subject_re = re.compile(r'^(?:;lang=([a-z]{1,8}(?:-[a-z0-9]{1,8})*)\s+)?(.*)$')
namespace_re = re.compile(r'^(?:(\S+) ?)?<(.*)>$')
def __init__(self, content, content_type, charset=None, sender=None, recipients=None, courtesy_recipients=None, subject=None, timestamp=None, required=None, additional_headers=None):
- if not isinstance(content, bytes):
- raise TypeError("content should be an instance of bytes")
self.content = content
self.content_type = content_type
self.charset = charset
self.sender = sender
self.recipients = recipients or []
self.courtesy_recipients = courtesy_recipients or []
self.subject = subject if isinstance(subject, (MultilingualText, type(None))) else MultilingualText(subject)
self.timestamp = ISOTimestamp(timestamp) if timestamp is not None else None
self.required = required or []
self.additional_headers = additional_headers or []
def encode(self):
namespaces = {'': CPIMNamespace(self.standard_namespace)}
header_list = []
if self.sender is not None:
+ print('Sender type %s' % type(self.sender))
header_list.append('From: {}'.format(self.sender))
header_list.extend('To: {}'.format(recipient) for recipient in self.recipients)
header_list.extend('cc: {}'.format(recipient) for recipient in self.courtesy_recipients)
if self.subject is not None:
header_list.append('Subject: {}'.format(self.subject))
header_list.extend('Subject:;lang={} {}'.format(language, translation) for language, translation in list(self.subject.translations.items()))
if self.timestamp is not None:
header_list.append('DateTime: {}'.format(self.timestamp))
if self.required:
header_list.append('Required: {}'.format(','.join(self.required)))
for header in self.additional_headers:
if namespaces.get(header.namespace.prefix) != header.namespace:
if header.namespace.prefix:
header_list.append('NS: {0.namespace.prefix} <{0.namespace}>'.format(header))
else:
header_list.append('NS: <{0.namespace}>'.format(header))
namespaces[header.namespace.prefix] = header.namespace
if header.namespace.prefix:
header_list.append('{0.namespace.prefix}.{0.name}: {0.value}'.format(header))
else:
header_list.append('{0.name}: {0.value}'.format(header))
- headers = '\r\n'.join(header.encode('cpim-header') for header in header_list)
+ headers = '\r\n'.join(header_list)
mime_message = EmailMessage()
mime_message.set_payload(self.content)
mime_message.set_type(self.content_type)
if self.charset is not None:
mime_message.set_param('charset', self.charset)
return headers + '\r\n\r\n' + mime_message.as_string(), 'message/cpim'
@classmethod
def decode(cls, message):
- if not isinstance(message, bytes):
- raise TypeError("message should be an instance of bytes")
headers, separator, body = message.partition('\r\n\r\n')
if not separator:
raise CPIMParserError('Invalid CPIM message')
sender = None
recipients = []
courtesy_recipients = []
subject = None
timestamp = None
required = []
additional_headers = []
namespaces = {'': CPIMNamespace(cls.standard_namespace)}
subjects = {}
for prefix, name, value in cls.headers_re.findall(headers):
namespace = namespaces.get(prefix)
+ print('--------------')
+ print(prefix, name, value)
if namespace is None or '.' in name:
continue
try:
- value = value.decode('cpim-header')
+ #value = value.decode('cpim-header')
if namespace == cls.standard_namespace:
if name == 'From':
sender = ChatIdentity.parse(value)
elif name == 'To':
recipients.append(ChatIdentity.parse(value))
elif name == 'cc':
courtesy_recipients.append(ChatIdentity.parse(value))
elif name == 'Subject':
match = cls.subject_re.match(value)
if match is None:
raise ValueError('Illegal Subject header: %r' % value)
lang, subject = match.groups()
# language tags must be ASCII
subjects[str(lang) if lang is not None else None] = subject
elif name == 'DateTime':
timestamp = ISOTimestamp(value)
elif name == 'Required':
required.extend(re.split(r'\s*,\s*', value))
elif name == 'NS':
match = cls.namespace_re.match(value)
if match is None:
raise ValueError('Illegal NS header: %r' % value)
prefix, uri = match.groups()
namespaces[prefix] = CPIMNamespace(uri, prefix)
else:
additional_headers.append(CPIMHeader(name, namespace, value))
else:
additional_headers.append(CPIMHeader(name, namespace, value))
except ValueError:
pass
if None in subjects:
subject = MultilingualText(subjects.pop(None), **subjects)
elif subjects:
subject = MultilingualText(**subjects)
mime_message = EmailParser().parsestr(body)
content_type = mime_message.get_content_type()
if content_type is None:
raise CPIMParserError("CPIM message missing Content-Type MIME header")
content = mime_message.get_payload()
charset = mime_message.get_content_charset()
return cls(content, content_type, charset, sender, recipients, courtesy_recipients, subject, timestamp, required, additional_headers)
class CPIMParserError(Exception): pass
class CPIMNamespace(str):
def __new__(cls, value, prefix=''):
obj = str.__new__(cls, value)
obj.prefix = prefix
return obj
class CPIMHeader(object):
def __init__(self, name, namespace, value):
self.name = name
self.namespace = namespace
self.value = value
class CPIMCodec(codecs.Codec):
character_map = {c: '\\u{:04x}'.format(c) for c in list(range(32)) + [127]}
character_map[ord('\\')] = '\\\\'
@classmethod
def encode(cls, input, errors='strict'):
return input.translate(cls.character_map).encode('utf-8', errors), len(input)
@classmethod
def decode(cls, input, errors='strict'):
return input.decode('utf-8', errors).encode('raw-unicode-escape', errors).decode('unicode-escape', errors), len(input)
def cpim_codec_search(name):
if name.lower() in ('cpim-header', 'cpim_header'):
return codecs.CodecInfo(name='CPIM-header',
encode=CPIMCodec.encode,
decode=CPIMCodec.decode,
incrementalencoder=codecs.IncrementalEncoder,
incrementaldecoder=codecs.IncrementalDecoder,
streamwriter=codecs.StreamWriter,
streamreader=codecs.StreamReader)
codecs.register(cpim_codec_search)
del cpim_codec_search