| @@ -9,6 +9,10 @@ from enum import ( | |||
| ) | |||
| import logging | |||
| import re | |||
| from random import ( | |||
| choices, | |||
| ) | |||
| import string | |||
| import socket | |||
| from aiodnsresolver import ( | |||
| @@ -19,6 +23,7 @@ from aiodnsresolver import ( | |||
| Message, | |||
| Resolver, | |||
| ResourceRecord, | |||
| ResolverLoggerAdapter, | |||
| pack, | |||
| parse, | |||
| recvfrom, | |||
| @@ -40,8 +45,27 @@ def get_logger_default(): | |||
| return logging.getLogger('dnsrewriteproxy') | |||
| class DnsProxyLoggerAdapter(logging.LoggerAdapter): | |||
| def process(self, msg, kwargs): | |||
| return \ | |||
| ('[dnsproxy] %s' % (msg,), kwargs) if not self.extra else \ | |||
| ('[dnsproxy:%s] %s' % (','.join(str(v) for v in self.extra.values()), msg), kwargs) | |||
| def get_logger_adapter_default(extra): | |||
| return DnsProxyLoggerAdapter(logging.getLogger('dnsrewriteproxy'), extra) | |||
| def get_resolver_logger_adapter_default(parent_adapter): | |||
| def _get_resolver_logger_adapter_default(dns_extra): | |||
| return ResolverLoggerAdapter(parent_adapter, dns_extra) | |||
| return _get_resolver_logger_adapter_default | |||
| def DnsProxy( | |||
| get_resolver=get_resolver_default, get_logger=get_logger_default, | |||
| get_resolver=get_resolver_default, | |||
| get_logger_adapter=get_logger_adapter_default, | |||
| get_resolver_logger_adapter=get_resolver_logger_adapter_default, | |||
| get_socket=get_socket_default, num_workers=1000, | |||
| rules=(), | |||
| ): | |||
| @@ -53,13 +77,15 @@ def DnsProxy( | |||
| REFUSED = 5 | |||
| loop = get_running_loop() | |||
| logger = get_logger() | |||
| logger = get_logger_adapter({}) | |||
| request_id_alphabet = string.ascii_letters + string.digits | |||
| # The "main" task of the server: it receives incoming requests and puts | |||
| # them in a queue that is then fetched from and processed by the proxy | |||
| # workers | |||
| async def server_worker(sock, resolve, stop): | |||
| logger.info('Starting') | |||
| upstream_queue = Queue(maxsize=num_workers) | |||
| # We have multiple upstream workers to be able to send multiple | |||
| @@ -71,27 +97,39 @@ def DnsProxy( | |||
| try: | |||
| while True: | |||
| request_data, addr = await recvfrom(loop, [sock], 512) | |||
| await upstream_queue.put((request_data, addr)) | |||
| request_logger = get_logger_adapter( | |||
| {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))}) | |||
| request_logger.info('Received request from %s', addr) | |||
| await upstream_queue.put((request_logger, request_data, addr)) | |||
| except CancelledError: | |||
| pass | |||
| except Exception: | |||
| logger.exception('Error in main loop') | |||
| finally: | |||
| # Finish upstream requests | |||
| logger.info('Stopping: waiting for requests to finish') | |||
| await upstream_queue.join() | |||
| logger.info('Stopping: cancelling workers...') | |||
| for upstream_task in upstream_worker_tasks: | |||
| upstream_task.cancel() | |||
| for upstream_task in upstream_worker_tasks: | |||
| try: | |||
| await upstream_task | |||
| except CancelledError: | |||
| pass | |||
| logger.info('Stopping: cancelling workers... (done)') | |||
| logger.info('Stopping: final cleanup') | |||
| await stop() | |||
| logger.info('Stopping: done') | |||
| async def upstream_worker(sock, resolve, upstream_queue): | |||
| while True: | |||
| request_data, addr = await upstream_queue.get() | |||
| request_logger, request_data, addr = await upstream_queue.get() | |||
| try: | |||
| response_data = await get_response_data(resolve, request_data) | |||
| request_logger.info('Processing request from %s', addr) | |||
| response_data = await get_response_data(request_logger, resolve, request_data) | |||
| # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError | |||
| # https://stackoverflow.com/a/59794872/1319998 | |||
| sock.sendto(response_data, addr) | |||
| @@ -100,7 +138,7 @@ def DnsProxy( | |||
| finally: | |||
| upstream_queue.task_done() | |||
| async def get_response_data(resolve, request_data): | |||
| async def get_response_data(request_logger, resolve, request_data): | |||
| # This may raise an exception, which is handled at a higher level. | |||
| # We can't [and I suspect shouldn't try to] return an error to the | |||
| # client, since we're not able to extract the QID, so the client won't | |||
| @@ -110,17 +148,18 @@ def DnsProxy( | |||
| try: | |||
| return pack( | |||
| error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else | |||
| (await proxy(resolve, query)) | |||
| (await proxy(request_logger, resolve, query)) | |||
| ) | |||
| except Exception: | |||
| logger.exception('Failed to proxy %s', query) | |||
| request_logger.exception('Failed to proxy %s', query) | |||
| return pack(error(query, ERRORS.SERVFAIL)) | |||
| async def proxy(resolve, query): | |||
| async def proxy(request_logger, resolve, query): | |||
| name_bytes = query.qd[0].name | |||
| name_str_lower = query.qd[0].name.lower().decode('idna') | |||
| request_logger.info('Name: %s', name_bytes) | |||
| logger.info('%s: received as bytes %s', name_str_lower, name_bytes) | |||
| name_str_lower = query.qd[0].name.lower().decode('idna') | |||
| request_logger.info('Decoded: %s', name_bytes) | |||
| for pattern, replace in rules: | |||
| rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower) | |||
| @@ -128,20 +167,22 @@ def DnsProxy( | |||
| break | |||
| else: | |||
| # No break was triggered, i.e. no match | |||
| logger.info('%s: does not match a rule', name_str_lower) | |||
| request_logger.info('Does not match a rule') | |||
| return error(query, ERRORS.REFUSED) | |||
| try: | |||
| ip_addresses = await resolve(rewritten_name_str, TYPES.A) | |||
| ip_addresses = await resolve( | |||
| rewritten_name_str, TYPES.A, | |||
| get_logger_adapter=get_resolver_logger_adapter(request_logger)) | |||
| except DnsRecordDoesNotExist: | |||
| logger.info('%s: does not exist', name_str_lower) | |||
| request_logger.info('Does not exist') | |||
| return error(query, ERRORS.NXDOMAIN) | |||
| except DnsResponseCode as dns_response_code_error: | |||
| logger.info('%s: received error frum upstream %s', | |||
| name_str_lower, dns_response_code_error.args[0]) | |||
| request_logger.info('Received error from upstream: %s', | |||
| name_str_lower, dns_response_code_error.args[0]) | |||
| return error(query, dns_response_code_error.args[0]) | |||
| logger.info('%s: resolved to %s', name_str_lower, ip_addresses) | |||
| request_logger.info('Resolved to %s', ip_addresses) | |||
| now = loop.time() | |||
| def ttl(ip_address): | |||