|
- from asyncio import (
- CancelledError,
- Queue,
- create_task,
- get_running_loop,
- )
- from enum import (
- IntEnum,
- )
- import logging
- import re
- from random import (
- choices,
- )
- import string
- import socket
-
- from aiodnsresolver import (
- RESPONSE,
- TYPES,
- DnsRecordDoesNotExist,
- DnsResponseCode,
- Message,
- Resolver,
- ResourceRecord,
- ResolverLoggerAdapter,
- pack,
- parse,
- recvfrom,
- )
-
-
def get_socket_default(host='', port=53):
    """Create the non-blocking IPv4 UDP socket the proxy listens on.

    By default the socket is bound to port 53 on all IPv4 interfaces, which
    is what ``DnsProxy`` uses when no custom ``get_socket`` is given. ``host``
    and ``port`` are optional keyword arguments, so unprivileged deployments
    and tests can listen elsewhere; ``port=0`` lets the OS pick a free port.

    Returns:
        A bound, non-blocking ``socket.socket`` (AF_INET, SOCK_DGRAM).

    Raises:
        OSError: if the socket cannot be created, configured or bound. In
            that case the socket is closed before the error propagates, so
            it is not leaked.
    """
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    try:
        sock.setblocking(False)
        sock.bind((host, port))
    except BaseException:
        # Binding commonly fails (privileged port, port already in use):
        # release the file descriptor rather than leaking it
        sock.close()
        raise
    return sock
-
-
def get_resolver_default():
    """Build the upstream resolver used when no custom factory is supplied.

    Returns the result of ``aiodnsresolver.Resolver()``, which ``DnsProxy``
    unpacks as a ``(resolve, clear_cache)`` pair.
    """
    resolver = Resolver()
    return resolver
-
-
class DnsProxyLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every message with a ``[dnsproxy...]`` prefix.

    With an empty (or falsy) ``extra``, messages become ``[dnsproxy] msg``.
    Otherwise each value of ``extra`` is converted with ``str``, the values
    are joined with commas, and the result goes into the tag, for example
    ``[dnsproxy:Ab3dE6gH] msg``.
    """

    def process(self, msg, kwargs):
        # Guard clause: nothing to add beyond the bare tag
        if not self.extra:
            return '[dnsproxy] %s' % (msg,), kwargs

        context = ','.join(map(str, self.extra.values()))
        return '[dnsproxy:%s] %s' % (context, msg), kwargs
-
-
def get_logger_adapter_default(extra):
    """Return a ``DnsProxyLoggerAdapter`` over the ``dnsrewriteproxy`` logger.

    ``extra`` is handed to the adapter unchanged. When it is non-empty, its
    values appear in the ``[dnsproxy:...]`` message prefix.
    """
    base_logger = logging.getLogger('dnsrewriteproxy')
    return DnsProxyLoggerAdapter(base_logger, extra)
-
-
def get_resolver_logger_adapter_default(parent_adapter):
    """Return a factory that creates resolver logger adapters.

    The returned callable takes one argument (``dns_extra``) and builds an
    ``aiodnsresolver.ResolverLoggerAdapter`` wrapping ``parent_adapter`` with
    that extra context. ``DnsProxy`` passes the factory to ``resolve`` as its
    ``get_logger_adapter`` argument.
    """
    return lambda dns_extra: ResolverLoggerAdapter(parent_adapter, dns_extra)
-
-
def DnsProxy(
    get_resolver=get_resolver_default,
    get_logger_adapter=get_logger_adapter_default,
    get_resolver_logger_adapter=get_resolver_logger_adapter_default,
    get_socket=get_socket_default, num_workers=1000,
    rules=(),
):
    """Create a DNS proxy that rewrites query names according to ``rules``.

    Must be called while an asyncio event loop is running, because it calls
    ``get_running_loop()``. It returns an async ``start`` function. Awaiting
    ``start()`` creates the listening socket and the resolver, then returns
    the server task. Cancelling that task:

    1. waits for already-queued requests to be processed,
    2. cancels the workers,
    3. closes the socket and clears the resolver cache.

    Only A queries are proxied. For each request:

    - The lower-cased, IDNA-decoded query name is tried against each
      ``(pattern, replace)`` pair in ``rules``, in order, using ``re.subn``.
    - The first pattern that matches rewrites the name, and the rewritten
      name is resolved upstream.
    - Non-A queries, and names that match no rule, get REFUSED.

    Args:
        get_resolver: zero-argument factory returning ``(resolve, clear_cache)``.
        get_logger_adapter: factory taking an ``extra`` dict and returning a
            logger adapter.
        get_resolver_logger_adapter: factory taking a request logger and
            returning the ``get_logger_adapter`` callable passed to ``resolve``.
        get_socket: zero-argument factory returning the socket to listen on.
            It is expected to be a bound, non-blocking UDP socket; the
            default returned by ``get_socket_default`` is.
        num_workers: number of concurrent upstream workers. It is also the
            maximum number of requests queued for them.
        rules: sequence of ``(pattern, replace)`` regex rewrite rules. It is
            iterated once per request, so it must be re-iterable.
    """

    class ERRORS(IntEnum):
        # DNS response codes (RCODE), numbered as in RFC 1035 section 4.1.1
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger_adapter({})
    request_id_alphabet = string.ascii_letters + string.digits

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers

    async def server_worker(sock, resolve, stop):
        """Receive requests on ``sock`` and queue them for the upstream workers.

        Loops until cancelled, or until receiving raises. On the way out it:

        1. waits for queued requests to be processed,
        2. cancels the workers,
        3. awaits ``stop()`` for final cleanup.
        """
        # Bounded, so that when every worker is busy ``put`` waits instead of
        # queueing an unlimited number of requests
        upstream_queue = Queue(maxsize=num_workers)

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently
        upstream_worker_tasks = [
            create_task(upstream_worker(sock, resolve, upstream_queue))
            for _ in range(num_workers)]

        try:
            while True:
                logger.info('Waiting for next request')
                # 512 bytes is the classic maximum UDP DNS message size
                # (RFC 1035 section 2.3.4). Presumably this is the most
                # recvfrom reads per datagram; confirm against aiodnsresolver.
                request_data, addr = await recvfrom(loop, [sock], 512)
                # A short random ID, used only to correlate log lines of one
                # request
                request_logger = get_logger_adapter(
                    {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))})
                request_logger.info('Received request from %s', addr)
                await upstream_queue.put((request_logger, request_data, addr))
        finally:
            logger.info('Stopping: waiting for requests to finish')
            await upstream_queue.join()

            logger.info('Stopping: cancelling workers...')
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except CancelledError:
                    pass
            logger.info('Stopping: cancelling workers... (done)')

            logger.info('Stopping: final cleanup')
            await stop()
            logger.info('Stopping: done')

    async def upstream_worker(sock, resolve, upstream_queue):
        """Process queued ``(logger, data, addr)`` items forever.

        For each item, the response is sent back to ``addr`` over ``sock``.
        Errors are logged rather than raised. ``task_done`` is always called
        so that ``upstream_queue.join()`` in ``server_worker`` can complete.
        """
        while True:
            request_logger, request_data, addr = await upstream_queue.get()

            try:
                request_logger.info('Processing request')
                response_data = await get_response_data(request_logger, resolve, request_data)
                # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
                # https://stackoverflow.com/a/59794872/1319998
                sock.sendto(response_data, addr)
            except Exception:
                request_logger.exception('Error processing request')
            finally:
                request_logger.info('Finished processing request')
                upstream_queue.task_done()

    async def get_response_data(request_logger, resolve, request_data):
        """Parse raw request bytes, proxy the query and return packed response bytes.

        Any failure while proxying is logged and answered with SERVFAIL.
        """
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        try:
            return pack(await proxy(request_logger, resolve, query))
        except Exception:
            request_logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(request_logger, resolve, query):
        """Answer ``query`` by rewriting its name via ``rules`` and resolving it upstream.

        Returns a response ``Message`` with one A record per resolved address.
        It returns an error response instead in these cases:

        - FORMERR: the query has no question.
        - REFUSED: the query is not an A query, or its name matches no rule.
        - NXDOMAIN: the rewritten name does not exist upstream.
        - The upstream rcode: when ``resolve`` raises ``DnsResponseCode``.
        """
        # Without a question there is nothing to proxy. Answer FORMERR
        # instead of letting ``query.qd[0]`` raise IndexError, which would be
        # reported as SERVFAIL with a traceback.
        if not query.qd:
            request_logger.info('No question in query')
            return error(query, ERRORS.FORMERR)

        question = query.qd[0]
        name_bytes = question.name
        request_logger.info('Name: %s', name_bytes)

        name_str_lower = name_bytes.lower().decode('idna')
        request_logger.info('Decoded: %s', name_str_lower)

        if question.qtype != TYPES.A:
            request_logger.info('Unhandled query type: %s', question.qtype)
            return error(query, ERRORS.REFUSED)

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
            if num_matches:
                request_logger.info('Matches rule (%s, %s)', pattern, replace)
                break
        else:
            # No break was triggered, i.e. no match
            request_logger.info('Does not match a rule')
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(
                rewritten_name_str, TYPES.A,
                get_logger_adapter=get_resolver_logger_adapter(request_logger))
        except DnsRecordDoesNotExist:
            request_logger.info('Does not exist')
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            request_logger.info('Received error from upstream: %s',
                                dns_response_code_error.args[0])
            return error(query, dns_response_code_error.args[0])

        request_logger.info('Resolved to %s', ip_addresses)
        now = loop.time()

        def ttl(ip_address):
            # Remaining lifetime, measured against the event loop clock and
            # clamped at zero. This assumes ``expires_at`` is in loop time;
            # confirm against aiodnsresolver.
            return int(max(0.0, ip_address.expires_at - now))

        # Answers carry the original (un-rewritten) query name, so the
        # client sees records for the name it asked about
        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        # RFC 1035 section 4.1.1: OPCODE and RD are copied from the query
        # into the response
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=query.opcode, aa=0, tc=0, rd=query.rd,
            ra=1, z=0, rcode=0, qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        """Create the socket and resolver, then start serving.

        Returns the server task. Cancel it to stop the proxy.
        """
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening.
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that. In that case close the socket we
        # just opened, so it isn't leaked.
        try:
            resolve, clear_cache = get_resolver()
        except BaseException:
            sock.close()
            raise

        async def stop():
            sock.close()
            await clear_cache()

        return create_task(server_worker(sock, resolve, stop))

    return start
-
-
def error(query, rcode):
    """Build an empty error response to ``query`` carrying ``rcode``.

    The response:

    - echoes the query's ID and question section;
    - echoes the query's OPCODE and RD flag, because RFC 1035 section 4.1.1
      says both are copied from the query into the response;
    - sets RA=1;
    - has no answer, authority or additional records.

    Args:
        query: the parsed incoming ``Message``.
        rcode: the DNS response code to return.
    """
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=query.opcode, aa=0, tc=0, rd=query.rd,
        ra=1, z=0, rcode=rcode, qd=query.qd, an=(), ns=(), ar=(),
    )
|