diff --git a/sipsimple/lookup-gevent.py b/sipsimple/lookup-gevent.py new file mode 100644 index 00000000..0ff7c901 --- /dev/null +++ b/sipsimple/lookup-gevent.py @@ -0,0 +1,551 @@ + +""" +Implements DNS lookups in the context of SIP, STUN and MSRP relay based +on RFC3263 and related standards. This can be used to determine the next +hop(s) and failover for routing of SIP messages and reservation of network +resources prior the starting of a SIP session. +""" + + + +import re +from itertools import chain +from time import time +from urllib.parse import urlparse + +# replace standard select and socket modules with versions from eventlib +from eventlib import coros, proc +import dns.name +import dns.query + +from gevent.resolver.dnspython import resolver as gresolver + +from application.notification import IObserver, NotificationCenter, NotificationData +from application.python import Null, limit +from application.python.decorator import decorator, preserve_signature +from application.python.types import Singleton +from dns import exception, rdatatype +from twisted.internet import reactor +from zope.interface import implementer + +from sipsimple.core import Route +from sipsimple.threading import run_in_twisted_thread +from sipsimple.threading.green import Command, InterruptCommand, run_in_waitable_green_thread + + +def domain_iterator(domain): + """ + A generator which returns the domain and its parent domains. + """ + while domain not in ('.', ''): + yield domain + domain = (domain.split('.', 1)+[''])[1] + + +@decorator +def post_dns_lookup_notifications(func): + @preserve_signature(func) + def wrapper(obj, *args, **kwargs): + notification_center = NotificationCenter() + try: + result = func(obj, *args, **kwargs) + except DNSLookupError as e: + notification_center.post_notification('DNSLookupDidFail', sender=obj, data=NotificationData(error=str(e))) + raise + else: + notification_center.post_notification('DNSLookupDidSucceed', sender=obj, data=NotificationData(result=result)) + return result + return wrapper + + +class DNSLookupError(Exception): + """ + The error raised by DNSLookup when a lookup cannot be performed. + """ + + +class DNSCache(object): + """ + A simple DNS cache which uses twisted's timers to invalidate its expired + data. + """ + def __init__(self): + self.data = {} + + def get(self, key): + return self.data.get(key, None) + + def put(self, key, value): + expiration = value.expiration-time() + if expiration > 0: + self.data[key] = value + reactor.callLater(limit(expiration, max=3600), self.data.pop, key, None) + + def flush(self, key=None): + if key is not None: + self.data.pop(key, None) + else: + self.data = {} + + +class InternalResolver(gresolver.Resolver): + def __init__(self, *args, **kw): + super(InternalResolver, self).__init__(*args, **kw) + if self.domain.to_text().endswith('local.'): + self.domain = dns.name.root + self.search = [item for item in self.search if not item.to_text().endswith('local.')] + + +class DNSResolver(gresolver.Resolver): + """ + The resolver used by DNSLookup. + + The lifetime setting on it applies to all the queries made on this resolver. + Each time a query is performed, its duration is subtracted from the lifetime + value. + """ + + def __init__(self): + gresolver.Resolver.__init__(self, configure=False) + dns_manager = DNSManager() + self.search = dns_manager.search + self.domain = dns_manager.domain + self.nameservers = dns_manager.nameservers + + def query(self, *args, **kw): + start_time = time() + try: + return gresolver.Resolver.query(self, *args, **kw) + finally: + self.lifetime -= min(self.lifetime, time()-start_time) + + +class SRVResult(object): + """ + Internal object used to save the result of SRV queries. + """ + def __init__(self, priority, weight, port, address): + self.priority = priority + self.weight = weight + self.port = port + self.address = address + + +class NAPTRResult(object): + """ + Internal object used to save the result of NAPTR queries. + """ + def __init__(self, service, order, preference, priority, weight, port, address): + self.service = service + self.order = order + self.preference = preference + self.priority = priority + self.weight = weight + self.port = port + self.address = address + + +class DNSLookup(object): + + cache = DNSCache() + + @run_in_waitable_green_thread + @post_dns_lookup_notifications + def lookup_service(self, uri, service, timeout=3.0, lifetime=15.0): + """ + Performs an SRV query to determine the servers used for the specified + service from the domain in uri.host. If this fails and falling back is + supported, also performs an A query on uri.host, returning the default + port of the service along with the IP addresses in the answer. + + The services supported are `stun' and 'msrprelay'. + + The DNSLookupDidSucceed notification contains a result attribute which + is a list of (address, port) tuples. The DNSLookupDidFail notification + contains an error attribute describing the error encountered. + """ + service_srv_record_map = {"stun": ("_stun._udp", 3478, False), + "msrprelay": ("_msrps._tcp", 2855, True)} + log_context = dict(context='lookup_service', service=service, uri=uri) + + try: + service_prefix, service_port, service_fallback = service_srv_record_map[service] + except KeyError: + raise DNSLookupError("Unknown service: %s" % service) + + try: + # If the host part of the URI is an IP address, we will not do any lookup + if re.match("^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", uri.host.decode()): + return [(uri.host.decode(), uri.port or service_port)] + + resolver = DNSResolver() + resolver.cache = self.cache + resolver.timeout = timeout + resolver.lifetime = lifetime + + record_name = '%s.%s' % (service_prefix, uri.host.decode()) + services = self._lookup_srv_records(resolver, [record_name], log_context=log_context) + if services[record_name]: + return [(result.address, result.port) for result in services[record_name]] + elif service_fallback: + addresses = self._lookup_a_records(resolver, [uri.host.decode()], log_context=log_context) + if addresses[uri.host.decode()]: + return [(addr, service_port) for addr in addresses[uri.host.decode()]] + except (gresolver.Timeout, gresolver.NoNameservers): + raise DNSLookupError('Timeout in lookup for %s servers for domain %s' % (service, uri.host.decode())) + else: + raise DNSLookupError('No %s servers found for domain %s' % (service, uri.host.decode())) + + + @run_in_waitable_green_thread + @post_dns_lookup_notifications + def lookup_sip_proxy(self, uri, supported_transports, timeout=3.0, lifetime=15.0): + """ + Performs an RFC 3263 compliant lookup of transport/ip/port combinations + for a particular SIP URI. As arguments it takes a SIPURI object + and a list of supported transports, in order of preference of the + application. It returns a list of Route objects that can be used in + order of preference. + + The DNSLookupDidSucceed notification contains a result attribute which + is a list of Route objects. The DNSLookupDidFail notification contains + an error attribute describing the error encountered. + """ + + naptr_service_transport_map = {"sips+d2t": "tls", + "sip+d2t": "tcp", + "sip+d2u": "udp"} + + transport_service_map = {"udp": "_sip._udp", + "tcp": "_sip._tcp", + "tls": "_sips._tcp"} + + log_context = dict(context='lookup_sip_proxy', uri=uri) + + if not supported_transports: + raise DNSLookupError("No transports are supported") + supported_transports = [transport.lower() for transport in supported_transports] + unknown_transports = set(supported_transports).difference(transport_service_map) + if unknown_transports: + raise DNSLookupError("Unknown transports: %s" % ', '.join(unknown_transports)) + + try: + # If the host part of the URI is an IP address, we will not do any lookup + transport = uri.transport.decode() if isinstance(uri.transport, bytes) else uri.transport + if re.match("^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", uri.host.decode()): + transport = 'tls' if uri.secure else transport.lower() + if transport not in supported_transports: + raise DNSLookupError("IP transport %s dictated by URI is not supported" % transport) + port = uri.port or (5061 if transport=='tls' else 5060) + route = [Route(address=uri.host, port=port, transport=transport)] + return route + + resolver = DNSResolver() + resolver.cache = self.cache + resolver.timeout = timeout + resolver.lifetime = lifetime + + # If the port is specified in the URI, we will only do an A lookup + if uri.port: + transport = 'tls' if uri.secure else transport.lower() + if transport not in supported_transports: + raise DNSLookupError("Host transport %s dictated by URI is not supported" % transport) + addresses = self._lookup_a_records(resolver, [uri.host.decode()], log_context=log_context) + if addresses[uri.host.decode()]: + return [Route(address=addr, port=uri.port, transport=transport, tls_name=uri.host) for addr in addresses[uri.host.decode()]] + + # If the transport was already set as a parameter on the SIP URI, only do SRV lookups + elif 'transport' in uri.parameters: + transport = uri.parameters['transport'].lower() + if transport not in supported_transports: + raise DNSLookupError("Requested lookup for URI with %s transport, but it is not supported" % transport) + if uri.secure and transport != 'tls': + raise DNSLookupError("Requested lookup for SIPS URI, but with %s transport parameter" % transport) + record_name = '%s.%s' % (transport_service_map[transport], uri.host.decode()) + services = self._lookup_srv_records(resolver, [record_name], log_context=log_context) + if services[record_name]: + return [Route(address=result.address, port=result.port, transport=transport, tls_name=uri.host) for result in services[record_name]] + else: + # If SRV lookup fails, try A lookup + addresses = self._lookup_a_records(resolver, [uri.host.decode()], log_context=log_context) + port = 5061 if transport=='tls' else 5060 + if addresses[uri.host.decode()]: + return [Route(address=addr, port=port, transport=transport, tls_name=uri.host) for addr in addresses[uri.host.decode()]] + + # Otherwise, it means we don't have a numeric IP address, a port isn't specified and neither is a transport. So we have to do a full NAPTR lookup + else: + # If the URI is a SIPS URI, we only support the TLS transport. + if uri.secure: + if 'tls' not in supported_transports: + raise DNSLookupError("Requested lookup for SIPS URI, but TLS transport is not supported") + supported_transports = ['tls'] + # First try NAPTR lookup + naptr_services = [service for service, transport in list(naptr_service_transport_map.items()) if transport in supported_transports] + try: + pointers = self._lookup_naptr_record(resolver, uri.host.decode(), naptr_services, log_context=log_context) + except (gresolver.Timeout, gresolver.NoNameservers): + pointers = [] + if pointers: + return [Route(address=result.address, port=result.port, transport=naptr_service_transport_map[result.service], tls_name=uri.host) for result in pointers] + else: + # If that fails, try SRV lookup + routes = [] + for transport in supported_transports: + record_name = '%s.%s' % (transport_service_map[transport], uri.host.decode()) + try: + services = self._lookup_srv_records(resolver, [record_name], log_context=log_context) + except (gresolver.Timeout, gresolver.NoNameservers): + continue + if services[record_name]: + routes.extend(Route(address=result.address, port=result.port, transport=transport, tls_name=uri.host) for result in services[record_name]) + if routes: + return routes + else: + # If SRV lookup fails, try A lookup + transport = 'tls' if uri.secure else 'udp' + if transport in supported_transports: + addresses = self._lookup_a_records(resolver, [uri.host.decode()], log_context=log_context) + port = 5061 if transport=='tls' else 5060 + if addresses[uri.host.decode()]: + return [Route(address=addr, port=port, transport=transport, tls_name=uri.host) for addr in addresses[uri.host.decode()]] + except (gresolver.Timeout, gresolver.NoNameservers): + raise DNSLookupError("Timeout in lookup for routes for SIP URI %s" % uri) + else: + raise DNSLookupError("No routes found for SIP URI %s" % uri) + + @run_in_waitable_green_thread + @post_dns_lookup_notifications + def lookup_xcap_server(self, uri, timeout=3.0, lifetime=15.0): + """ + Performs a TXT query against xcap. and returns all results + that look like HTTP URIs. + """ + log_context = dict(context='lookup_xcap_server', uri=uri) + notification_center = NotificationCenter() + + try: + # If the host part of the URI is an IP address, we cannot not do any lookup + if re.match("^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", uri.host.decode()): + raise DNSLookupError("Cannot perform DNS query because the host is an IP address") + + resolver = DNSResolver() + resolver.cache = self.cache + resolver.timeout = timeout + resolver.lifetime = lifetime + + record_name = 'xcap.%s' % uri.host.decode() + results = [] + try: + answer = resolver.query(record_name, rdatatype.TXT) + except (gresolver.Timeout, gresolver.NoNameservers) as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='TXT', query_name=str(record_name), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + raise + except exception.DNSException as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='TXT', query_name=str(record_name), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + else: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='TXT', query_name=str(record_name), nameservers=resolver.nameservers, answer=answer, error=None, **log_context)) + for result_uri in list(chain(*(r.strings for r in answer.rrset))): + parsed_uri = urlparse(result_uri.decode()) + if parsed_uri.scheme in ('http', 'https') and parsed_uri.netloc: + results.append(result_uri.decode()) + if not results: + raise DNSLookupError('No XCAP servers found for domain %s' % uri.host.decode()) + return results + except (gresolver.Timeout, gresolver.NoNameservers): + raise DNSLookupError('Timeout in lookup for XCAP servers for domain %s' % uri.host.decode()) + + + def _lookup_a_records(self, resolver, hostnames, additional_records=[], log_context={}): + notification_center = NotificationCenter() + additional_addresses = dict((rset.name.to_text(), rset) for rset in additional_records if rset.rdtype == rdatatype.A) + addresses = {} + for hostname in hostnames: + if hostname in additional_addresses: + addresses[hostname] = [r.address for r in additional_addresses[hostname]] + else: + try: + answer = resolver.query(hostname, rdatatype.A) + except (gresolver.Timeout, gresolver.NoNameservers) as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='A', query_name=str(hostname), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + raise + except exception.DNSException as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='A', query_name=str(hostname), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + addresses[hostname] = [] + else: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='A', query_name=str(hostname), nameservers=resolver.nameservers, answer=answer, error=None, **log_context)) + addresses[hostname] = [r.address for r in answer.rrset] + return addresses + + + def _lookup_srv_records(self, resolver, srv_names, additional_records=[], log_context={}): + notification_center = NotificationCenter() + additional_services = dict((rset.name.to_text(), rset) for rset in additional_records if rset.rdtype == rdatatype.SRV) + services = {} + for srv_name in srv_names: + services[srv_name] = [] + if srv_name in additional_services: + addresses = self._lookup_a_records(resolver, [r.target.to_text() for r in additional_services[srv_name]], additional_records) + for record in additional_services[srv_name]: + services[srv_name].extend(SRVResult(record.priority, record.weight, record.port, addr) for addr in addresses.get(record.target.to_text(), ())) + else: + try: + answer = resolver.query(srv_name, rdatatype.SRV) + except (gresolver.Timeout, gresolver.NoNameservers) as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='SRV', query_name=str(srv_name), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + raise + except exception.DNSException as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='SRV', query_name=str(srv_name), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + else: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='SRV', query_name=str(srv_name), nameservers=resolver.nameservers, answer=answer, error=None, **log_context)) + addresses = self._lookup_a_records(resolver, [r.target.to_text() for r in answer.rrset], answer.response.additional, log_context) + for record in answer.rrset: + services[srv_name].extend(SRVResult(record.priority, record.weight, record.port, addr) for addr in addresses.get(record.target.to_text(), ())) + services[srv_name].sort(key=lambda result: (result.priority, -result.weight)) + return services + + + def _lookup_naptr_record(self, resolver, domain, services, log_context={}): + notification_center = NotificationCenter() + pointers = [] + try: + answer = resolver.query(domain, rdatatype.NAPTR) + except (gresolver.Timeout, gresolver.NoNameservers) as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='NAPTR', query_name=str(domain), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + raise + except exception.DNSException as e: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='NAPTR', query_name=str(domain), nameservers=resolver.nameservers, answer=None, error=e, **log_context)) + else: + notification_center.post_notification('DNSLookupTrace', sender=self, data=NotificationData(query_type='NAPTR', query_name=str(domain), nameservers=resolver.nameservers, answer=answer, error=None, **log_context)) + records = [r for r in answer.rrset if r.service.decode().lower() in services] + services = self._lookup_srv_records(resolver, [r.replacement.to_text() for r in records], answer.response.additional, log_context) + + for record in records: + pointers.extend(NAPTRResult(record.service.decode().lower(), record.order, record.preference, r.priority, r.weight, r.port, r.address) for r in services.get(record.replacement.to_text(), ())) + + pointers.sort(key=lambda result: (result.order, result.preference)) + return pointers + + +@implementer(IObserver) +class DNSManager(object, metaclass=Singleton): + + def __init__(self): + try: + default_resolver = InternalResolver() + except (gresolver.NoResolverConfiguration, gresolver.NoNameservers) as e: + default_resolver = Null + + self.search = default_resolver.search + self.domain = default_resolver.domain + self.google_nameservers = ['8.8.8.8', '8.8.4.4'] + self.nameservers = default_resolver.nameservers or [] + self.probed_domain = 'sip2sip.info.' + self._channel = coros.queue() + self._proc = None + self._timer = None + self._wakeup_timer = None + notification_center = NotificationCenter() + notification_center.add_observer(self, name='SystemIPAddressDidChange') + notification_center.add_observer(self, name='SystemDidWakeUpFromSleep') + + @property + def nameservers(self): + return self.__dict__['nameservers'] + + @nameservers.setter + def nameservers(self, value): + old_value = self.__dict__.get('nameservers', Null) + self.__dict__['nameservers'] = value + if old_value is Null: + NotificationCenter().post_notification('DNSResolverDidInitialize', sender=self, data=NotificationData(nameservers=value)) + elif value != old_value: + NotificationCenter().post_notification('DNSNameserversDidChange', sender=self, data=NotificationData(nameservers=value)) + + def start(self): + self._proc = proc.spawn(self._run) + self._channel.send(Command('probe_dns')) + + def stop(self): + if self._proc is not None: + self._proc.kill() + self._proc = None + if self._timer is not None and self._timer.active(): + self._timer.cancel() + self._timer = None + if self._wakeup_timer is not None and self._wakeup_timer.active(): + self._wakeup_timer.cancel() + self._wakeup_timer = None + + def _run(self): + while True: + try: + command = self._channel.wait() + handler = getattr(self, '_CH_%s' % command.name) + handler(command) + except InterruptCommand: + pass + + def _CH_probe_dns(self, command): + if self._timer is not None and self._timer.active(): + self._timer.cancel() + self._timer = None + + try: + resolver = InternalResolver() + except (gresolver.NoResolverConfiguration, gresolver.NoNameservers) as e: + self._timer = reactor.callLater(15, self._channel.send, Command('probe_dns')) + return + + self.domain = resolver.domain + self.search = resolver.search + local_nameservers = resolver.nameservers + # probe local resolver + resolver.timeout = 1 + resolver.lifetime = 3 + try: + answer = resolver.query(self.probed_domain, rdatatype.NAPTR) + if not any(record.rdtype == rdatatype.NAPTR for record in answer.rrset): + raise exception.DNSException("No NAPTR records found") + answer = resolver.query("_sip._udp.%s" % self.probed_domain, rdatatype.SRV) + if not any(record.rdtype == rdatatype.SRV for record in answer.rrset): + raise exception.DNSException("No SRV records found") + except (gresolver.Timeout, gresolver.NoResolverConfiguration, gresolver.NoNameservers, exception.DNSException): + pass + else: + self.nameservers = resolver.nameservers + return + # local resolver failed. probe google resolver + resolver.nameservers = self.google_nameservers + resolver.timeout = 2 + resolver.lifetime = 4 + try: + answer = resolver.query(self.probed_domain, rdatatype.NAPTR) + if not any(record.rdtype == rdatatype.NAPTR for record in answer.rrset): + raise exception.DNSException("No NAPTR records found") + except (gresolver.Timeout, exception.DNSException): + pass + else: + self.nameservers = resolver.nameservers + return + # google resolver failed. fallback to local resolver and schedule another probe for later + self.nameservers = local_nameservers + self._timer = reactor.callLater(15, self._channel.send, Command('probe_dns')) + + @run_in_twisted_thread + def handle_notification(self, notification): + handler = getattr(self, '_NH_%s' % notification.name, Null) + handler(notification) + + def _NH_SystemIPAddressDidChange(self, notification): + self._proc.kill(InterruptCommand) + self._channel.send(Command('probe_dns')) + + def _NH_SystemDidWakeUpFromSleep(self, notification): + if self._wakeup_timer is None: + def wakeup_action(): + self._proc.kill(InterruptCommand) + self._channel.send(Command('probe_dns')) + self._wakeup_timer = None + self._wakeup_timer = reactor.callLater(5, wakeup_action) # wait for system to stabilize + +