| @@ -9,6 +9,10 @@ from enum import ( | |||||
| ) | ) | ||||
| import logging | import logging | ||||
| import re | import re | ||||
| from random import ( | |||||
| choices, | |||||
| ) | |||||
| import string | |||||
| import socket | import socket | ||||
| from aiodnsresolver import ( | from aiodnsresolver import ( | ||||
| @@ -19,6 +23,7 @@ from aiodnsresolver import ( | |||||
| Message, | Message, | ||||
| Resolver, | Resolver, | ||||
| ResourceRecord, | ResourceRecord, | ||||
| ResolverLoggerAdapter, | |||||
| pack, | pack, | ||||
| parse, | parse, | ||||
| recvfrom, | recvfrom, | ||||
| @@ -40,8 +45,27 @@ def get_logger_default(): | |||||
| return logging.getLogger('dnsrewriteproxy') | return logging.getLogger('dnsrewriteproxy') | ||||
| class DnsProxyLoggerAdapter(logging.LoggerAdapter): | |||||
| def process(self, msg, kwargs): | |||||
| return \ | |||||
| ('[dnsproxy] %s' % (msg,), kwargs) if not self.extra else \ | |||||
| ('[dnsproxy:%s] %s' % (','.join(str(v) for v in self.extra.values()), msg), kwargs) | |||||
| def get_logger_adapter_default(extra): | |||||
| return DnsProxyLoggerAdapter(logging.getLogger('dnsrewriteproxy'), extra) | |||||
| def get_resolver_logger_adapter_default(parent_adapter): | |||||
| def _get_resolver_logger_adapter_default(dns_extra): | |||||
| return ResolverLoggerAdapter(parent_adapter, dns_extra) | |||||
| return _get_resolver_logger_adapter_default | |||||
| def DnsProxy( | def DnsProxy( | ||||
| get_resolver=get_resolver_default, get_logger=get_logger_default, | |||||
| get_resolver=get_resolver_default, | |||||
| get_logger_adapter=get_logger_adapter_default, | |||||
| get_resolver_logger_adapter=get_resolver_logger_adapter_default, | |||||
| get_socket=get_socket_default, num_workers=1000, | get_socket=get_socket_default, num_workers=1000, | ||||
| rules=(), | rules=(), | ||||
| ): | ): | ||||
| @@ -53,13 +77,15 @@ def DnsProxy( | |||||
| REFUSED = 5 | REFUSED = 5 | ||||
| loop = get_running_loop() | loop = get_running_loop() | ||||
| logger = get_logger() | |||||
| logger = get_logger_adapter({}) | |||||
| request_id_alphabet = string.ascii_letters + string.digits | |||||
| # The "main" task of the server: it receives incoming requests and puts | # The "main" task of the server: it receives incoming requests and puts | ||||
| # them in a queue that is then fetched from and processed by the proxy | # them in a queue that is then fetched from and processed by the proxy | ||||
| # workers | # workers | ||||
| async def server_worker(sock, resolve, stop): | async def server_worker(sock, resolve, stop): | ||||
| logger.info('Starting') | |||||
| upstream_queue = Queue(maxsize=num_workers) | upstream_queue = Queue(maxsize=num_workers) | ||||
| # We have multiple upstream workers to be able to send multiple | # We have multiple upstream workers to be able to send multiple | ||||
| @@ -71,27 +97,39 @@ def DnsProxy( | |||||
| try: | try: | ||||
| while True: | while True: | ||||
| request_data, addr = await recvfrom(loop, [sock], 512) | request_data, addr = await recvfrom(loop, [sock], 512) | ||||
| await upstream_queue.put((request_data, addr)) | |||||
| request_logger = get_logger_adapter( | |||||
| {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))}) | |||||
| request_logger.info('Received request from %s', addr) | |||||
| await upstream_queue.put((request_logger, request_data, addr)) | |||||
| except CancelledError: | |||||
| pass | |||||
| except Exception: | |||||
| logger.exception('Error in main loop') | |||||
| finally: | finally: | ||||
| # Finish upstream requests | |||||
| logger.info('Stopping: waiting for requests to finish') | |||||
| await upstream_queue.join() | await upstream_queue.join() | ||||
| logger.info('Stopping: cancelling workers...') | |||||
| for upstream_task in upstream_worker_tasks: | for upstream_task in upstream_worker_tasks: | ||||
| upstream_task.cancel() | upstream_task.cancel() | ||||
| for upstream_task in upstream_worker_tasks: | for upstream_task in upstream_worker_tasks: | ||||
| try: | try: | ||||
| await upstream_task | await upstream_task | ||||
| except CancelledError: | except CancelledError: | ||||
| pass | pass | ||||
| logger.info('Stopping: cancelling workers... (done)') | |||||
| logger.info('Stopping: final cleanup') | |||||
| await stop() | await stop() | ||||
| logger.info('Stopping: done') | |||||
| async def upstream_worker(sock, resolve, upstream_queue): | async def upstream_worker(sock, resolve, upstream_queue): | ||||
| while True: | while True: | ||||
| request_data, addr = await upstream_queue.get() | |||||
| request_logger, request_data, addr = await upstream_queue.get() | |||||
| try: | try: | ||||
| response_data = await get_response_data(resolve, request_data) | |||||
| request_logger.info('Processing request from %s', addr) | |||||
| response_data = await get_response_data(request_logger, resolve, request_data) | |||||
| # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError | # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError | ||||
| # https://stackoverflow.com/a/59794872/1319998 | # https://stackoverflow.com/a/59794872/1319998 | ||||
| sock.sendto(response_data, addr) | sock.sendto(response_data, addr) | ||||
| @@ -100,7 +138,7 @@ def DnsProxy( | |||||
| finally: | finally: | ||||
| upstream_queue.task_done() | upstream_queue.task_done() | ||||
| async def get_response_data(resolve, request_data): | |||||
| async def get_response_data(request_logger, resolve, request_data): | |||||
| # This may raise an exception, which is handled at a higher level. | # This may raise an exception, which is handled at a higher level. | ||||
| # We can't [and I suspect shouldn't try to] return an error to the | # We can't [and I suspect shouldn't try to] return an error to the | ||||
| # client, since we're not able to extract the QID, so the client won't | # client, since we're not able to extract the QID, so the client won't | ||||
| @@ -110,17 +148,18 @@ def DnsProxy( | |||||
| try: | try: | ||||
| return pack( | return pack( | ||||
| error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else | error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else | ||||
| (await proxy(resolve, query)) | |||||
| (await proxy(request_logger, resolve, query)) | |||||
| ) | ) | ||||
| except Exception: | except Exception: | ||||
| logger.exception('Failed to proxy %s', query) | |||||
| request_logger.exception('Failed to proxy %s', query) | |||||
| return pack(error(query, ERRORS.SERVFAIL)) | return pack(error(query, ERRORS.SERVFAIL)) | ||||
| async def proxy(resolve, query): | |||||
| async def proxy(request_logger, resolve, query): | |||||
| name_bytes = query.qd[0].name | name_bytes = query.qd[0].name | ||||
| name_str_lower = query.qd[0].name.lower().decode('idna') | |||||
| request_logger.info('Name: %s', name_bytes) | |||||
| logger.info('%s: received as bytes %s', name_str_lower, name_bytes) | |||||
| name_str_lower = query.qd[0].name.lower().decode('idna') | |||||
| request_logger.info('Decoded: %s', name_bytes) | |||||
| for pattern, replace in rules: | for pattern, replace in rules: | ||||
| rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower) | rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower) | ||||
| @@ -128,20 +167,22 @@ def DnsProxy( | |||||
| break | break | ||||
| else: | else: | ||||
| # No break was triggered, i.e. no match | # No break was triggered, i.e. no match | ||||
| logger.info('%s: does not match a rule', name_str_lower) | |||||
| request_logger.info('Does not match a rule') | |||||
| return error(query, ERRORS.REFUSED) | return error(query, ERRORS.REFUSED) | ||||
| try: | try: | ||||
| ip_addresses = await resolve(rewritten_name_str, TYPES.A) | |||||
| ip_addresses = await resolve( | |||||
| rewritten_name_str, TYPES.A, | |||||
| get_logger_adapter=get_resolver_logger_adapter(request_logger)) | |||||
| except DnsRecordDoesNotExist: | except DnsRecordDoesNotExist: | ||||
| logger.info('%s: does not exist', name_str_lower) | |||||
| request_logger.info('Does not exist') | |||||
| return error(query, ERRORS.NXDOMAIN) | return error(query, ERRORS.NXDOMAIN) | ||||
| except DnsResponseCode as dns_response_code_error: | except DnsResponseCode as dns_response_code_error: | ||||
| logger.info('%s: received error frum upstream %s', | |||||
| name_str_lower, dns_response_code_error.args[0]) | |||||
| request_logger.info('Received error from upstream: %s', | |||||
| name_str_lower, dns_response_code_error.args[0]) | |||||
| return error(query, dns_response_code_error.args[0]) | return error(query, dns_response_code_error.args[0]) | ||||
| logger.info('%s: resolved to %s', name_str_lower, ip_addresses) | |||||
| request_logger.info('Resolved to %s', ip_addresses) | |||||
| now = loop.time() | now = loop.time() | ||||
| def ttl(ip_address): | def ttl(ip_address): | ||||