"""A DNS proxy that rewrites the names of A-record queries before resolving.

Each incoming UDP query whose name matches one of the configured regex rules
is rewritten with that rule and resolved upstream; anything else is refused.
"""
from asyncio import (
    CancelledError,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
from random import (
    choices,
)
import string
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    ResolverLoggerAdapter,
    pack,
    parse,
    recvfrom,
)


def get_socket_default():
    """Return a non-blocking IPv4 UDP socket bound to port 53 on all interfaces.

    Raises:
        OSError: if the socket cannot be created or bound (for example when
            the port is already in use or the process may not bind to it).
    """
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    """Return a default-configured ``aiodnsresolver.Resolver()``.

    ``DnsProxy`` unpacks the result as a ``(resolve, clear_cache)`` pair.
    """
    return Resolver()


def get_logger_default():
    """Return the ``logging`` logger named ``'dnsrewriteproxy'``."""
    return logging.getLogger('dnsrewriteproxy')


class DnsProxyLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with a ``[dnsproxy...]`` tag."""

    def process(self, msg, kwargs):
        """Prefix ``msg`` with the tag and return ``(message, kwargs)``.

        With no ``extra`` context the tag is ``[dnsproxy]``; otherwise the
        values of ``extra`` are comma-joined into ``[dnsproxy:v1,v2]``.
        ``kwargs`` is passed through unchanged.
        """
        if not self.extra:
            return '[dnsproxy] %s' % (msg,), kwargs

        context = ','.join(str(value) for value in self.extra.values())
        return '[dnsproxy:%s] %s' % (context, msg), kwargs


def get_logger_adapter_default(extra):
    """Return a ``DnsProxyLoggerAdapter`` over the default logger.

    Args:
        extra: mapping of context values to show in the log prefix; may be
            empty.
    """
    return DnsProxyLoggerAdapter(get_logger_default(), extra)


def get_resolver_logger_adapter_default(parent_adapter):
    """Return a factory of resolver logger adapters chained to ``parent_adapter``.

    The returned callable takes the resolver's own ``extra`` mapping and wraps
    ``parent_adapter`` in a ``ResolverLoggerAdapter``, so resolver messages
    pass through the proxy's request logger.
    """
    def _get_resolver_logger_adapter_default(dns_extra):
        return ResolverLoggerAdapter(parent_adapter, dns_extra)

    return _get_resolver_logger_adapter_default


def DnsProxy(
        get_resolver=get_resolver_default,
        get_logger_adapter=get_logger_adapter_default,
        get_resolver_logger_adapter=get_resolver_logger_adapter_default,
        get_socket=get_socket_default,
        num_workers=1000,
        rules=(),
):
    """Build a DNS rewrite proxy and return the coroutine function starting it.

    Must be called while an asyncio event loop is running, because the running
    loop is captured at construction time.

    Args:
        get_resolver: zero-argument callable returning a
            ``(resolve, clear_cache)`` pair. ``resolve`` is awaited with the
            rewritten name, and ``clear_cache`` is awaited during shutdown.
        get_logger_adapter: callable taking a mapping of ``extra`` context and
            returning a logger adapter. It is called once with ``{}`` for the
            proxy-wide logger and once per request with a random request id.
        get_resolver_logger_adapter: callable taking the per-request logger;
            its result is passed to ``resolve`` as ``get_logger_adapter``.
        get_socket: zero-argument callable returning the socket to serve on.
        num_workers: number of worker tasks, and also the maximum number of
            requests queued while waiting for a free worker.
        rules: sequence of ``(pattern, replace)`` pairs, tried in order with
            ``re.subn`` against the lower-cased query name. The first rule
            with at least one match wins. It is iterated for every request,
            so it must be re-iterable (a tuple or list, not a generator).

    Returns:
        ``start``, a coroutine function. Awaiting ``start()`` creates the
        socket and resolver and returns the ``asyncio.Task`` running the
        server. Cancelling that task runs the shutdown sequence: queued
        requests are finished, the workers are cancelled, the socket is
        closed and the resolver cache is cleared.

    Raises:
        RuntimeError: if there is no running event loop.
    """

    class ERRORS(IntEnum):
        # DNS response codes (RCODE), numbered as in RFC 1035
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger_adapter({})
    request_id_alphabet = string.ascii_letters + string.digits

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve, stop):
        """Receive requests from ``sock`` forever and queue them for workers.

        However the receive loop exits (normally by cancellation), it waits
        for the queued requests to finish, cancels the workers and then awaits
        ``stop()`` for the final cleanup.
        """
        upstream_queue = Queue(maxsize=num_workers)

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently
        upstream_worker_tasks = [
            create_task(upstream_worker(sock, resolve, upstream_queue))
            for _ in range(num_workers)
        ]

        try:
            while True:
                logger.info('Waiting for next request')
                # 512 bytes is the classic limit of a DNS message over UDP;
                # presumably the maximum read size - confirm against recvfrom
                request_data, addr = await recvfrom(loop, [sock], 512)
                request_id = ''.join(choices(request_id_alphabet, k=8))
                request_logger = get_logger_adapter(
                    {'dnsrewriteproxy_requestid': request_id})
                request_logger.info('Received request from %s', addr)
                # Blocks when the queue is full, so no more requests are read
                # from the socket until a worker frees a slot
                await upstream_queue.put((request_logger, request_data, addr))
        finally:
            logger.info('Stopping: waiting for requests to finish')
            await upstream_queue.join()

            logger.info('Stopping: cancelling workers...')
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except CancelledError:
                    # Expected: the worker was cancelled just above
                    pass
            logger.info('Stopping: cancelling workers... (done)')

            logger.info('Stopping: final cleanup')
            await stop()
            logger.info('Stopping: done')

    async def upstream_worker(sock, resolve, upstream_queue):
        """Process queued requests one at a time until cancelled.

        A failure while handling one request is logged and never propagates,
        so a single bad request cannot kill the worker.
        """
        while True:
            request_logger, request_data, addr = await upstream_queue.get()

            try:
                request_logger.info('Processing request')
                response_data = await get_response_data(
                    request_logger, resolve, request_data)
                # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
                # https://stackoverflow.com/a/59794872/1319998
                sock.sendto(response_data, addr)
            except Exception:
                request_logger.exception('Error processing request')
            finally:
                request_logger.info('Finished processing request')
                # Always mark the item done, otherwise the join() performed
                # on shutdown would never return
                upstream_queue.task_done()

    async def get_response_data(request_logger, resolve, request_data):
        """Return the packed response for the raw request ``request_data``.

        If proxying fails, a packed SERVFAIL response is returned instead. An
        exception raised while parsing the request itself is propagated.
        """
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        try:
            return pack(await proxy(request_logger, resolve, query))
        except Exception:
            request_logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(request_logger, resolve, query):
        """Return the response ``Message`` for the parsed ``query``.

        Only the first question is considered. Queries that are not for an A
        record, or whose name matches no rule, are answered with REFUSED.
        Upstream failures are mapped to NXDOMAIN or to the response code
        reported by the resolver. Any other exception propagates.
        """
        question = query.qd[0]
        name_bytes = question.name
        request_logger.info('Name: %s', name_bytes)

        name_str_lower = name_bytes.lower().decode('idna')
        request_logger.info('Decoded: %s', name_str_lower)

        if question.qtype != TYPES.A:
            request_logger.info('Unhandled query type: %s', question.qtype)
            return error(query, ERRORS.REFUSED)

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
            if num_matches:
                request_logger.info('Matches rule (%s, %s)', pattern, replace)
                break
        else:
            # No break was triggered, i.e. no match
            request_logger.info('Does not match a rule')
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(
                rewritten_name_str, TYPES.A,
                get_logger_adapter=get_resolver_logger_adapter(request_logger))
        except DnsRecordDoesNotExist:
            request_logger.info('Does not exist')
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            upstream_rcode = dns_response_code_error.args[0]
            request_logger.info('Received error from upstream: %s', upstream_rcode)
            return error(query, upstream_rcode)

        request_logger.info('Resolved to %s', ip_addresses)

        now = loop.time()

        def ttl(ip_address):
            # Whole seconds until expiry, never negative. Assumes expires_at
            # is on the same clock as loop.time() - TODO confirm
            return int(max(0.0, ip_address.expires_at - now))

        # The records carry the name as originally asked, not the rewritten
        # one, so the answer matches the client's question
        response_records = tuple(
            ResourceRecord(
                name=name_bytes, qtype=TYPES.A, qclass=1,
                ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        """Create the socket and resolver, then start serving.

        Returns:
            The ``asyncio.Task`` running the server; cancel it to stop.

        Raises:
            Whatever ``get_socket`` or ``get_resolver`` raise. If
            ``get_resolver`` fails, the socket already created is closed
            before the exception propagates.
        """
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        try:
            resolve, clear_cache = get_resolver()
        except BaseException:
            # Nothing else would ever close the socket, so release it here
            # rather than leaking it (and the port it is bound to)
            sock.close()
            raise

        async def stop():
            sock.close()
            await clear_cache()

        return create_task(server_worker(sock, resolve, stop))

    return start


def error(query, rcode):
    """Return an answerless response ``Message`` to ``query`` with ``rcode``.

    The query id and question section are copied from ``query`` so the client
    can match the response with its request.
    """
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )