from asyncio import (
    CancelledError,
    Future,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    pack,
    parse,
)


def get_socket_default():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock
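

# For illustration only: binding to port 53 above typically requires root or
# CAP_NET_BIND_SERVICE, so during development a custom get_socket can bind to
# a non-privileged port instead. get_socket_high_port is a hypothetical name,
# not part of this module's API
#
#   def get_socket_high_port():
#       sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
#       sock.setblocking(False)
#       sock.bind(('127.0.0.1', 5353))
#       return sock
#
#   start = DnsProxy(get_socket=get_socket_high_port)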


def get_resolver_default():
    return Resolver()


def get_resolver_with_upstream(upstream):
    # Try the single upstream server up to five times, with a 0.5 second
    # timeout on each attempt
    async def get_nameservers(_, __):
        for _ in range(0, 5):
            yield (0.5, (upstream, 53))

    return Resolver(get_nameservers=get_nameservers)
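

# For illustration only: a proxy that resolves rewritten names via a fixed
# upstream server can be constructed by binding get_resolver_with_upstream
# with functools.partial, since get_resolver is called with no arguments.
# 8.8.8.8 is just an example address
#
#   from functools import partial
#   start = DnsProxy(get_resolver=partial(get_resolver_with_upstream, '8.8.8.8'))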


def get_logger_default():
    return logging.getLogger('dnsrewriteproxy')


def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
        rules=(),
):
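    # Each element of rules is a (pattern, replace) pair as accepted by
    # re.subn: the first pattern that matches a queried name determines the
    # rewritten name resolved upstream, and queries that match no rule are
    # answered with REFUSED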

    class ERRORS(IntEnum):
        # A subset of the DNS response codes of RFC 1035 (4, NOTIMP, is
        # unused here)
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

        def __str__(self):
            return self.name

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them on a queue, from which the upstream workers fetch and process them

    async def server_worker(sock, resolve):
        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
        upstream_queue = Queue(maxsize=upstream_queue_maxsize)

        # It would "usually" be ok to send downstream from multiple tasks, but
        # if the socket has a full buffer, it would raise a BlockingIOError,
        # and we would need to attach a writer. We can only attach one writer
        # per underlying file, and since we have a single socket, we have a
        # single file. So we send downstream from a single task
        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                request_data, addr = await recvfrom(loop, sock, 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests, which can add to the downstream queue
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()

            # Ensure we have sent the responses downstream
            await downstream_queue.join()
            downstream_worker_task.cancel()

            # Wait for the tasks to really be finished. From Python 3.8,
            # CancelledError is not a subclass of Exception, so it's caught
            # explicitly
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except (Exception, CancelledError):
                    pass
            try:
                await downstream_worker_task
            except (Exception, CancelledError):
                pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from get_response_data %s', addr)
                upstream_queue.task_done()
                continue

            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    async def downstream_worker(sock, downstream_queue):
        while True:
            response_data, addr = await downstream_queue.get()
            await sendto(loop, sock, response_data, addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no rule matched the name
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        # The TTL of each response record is the time remaining until the
        # resolver's cached record expires, clamped at zero
        def ttl(ip_address):
            return int(max(0.0, ip_address.expires_at - loop.time()))

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass

            sock.close()
            await clear_cache()

        return stop

    return start


def error(query, rcode):
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )


async def recvfrom(loop, sock, max_bytes):
    # Nonblocking recvfrom: try the call immediately, and if the socket has
    # nothing to read, register a reader callback that fulfils a Future once
    # data (or an exception) arrives
    try:
        return sock.recvfrom(max_bytes)
    except BlockingIOError:
        pass

    def reader():
        try:
            (data, addr) = sock.recvfrom(max_bytes)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_result((data, addr))

    fileno = sock.fileno()
    result = Future()
    loop.add_reader(fileno, reader)

    try:
        return await result
    finally:
        loop.remove_reader(fileno)


async def sendto(loop, sock, data, addr):
    # In our case, the UDP responses will always be 512 bytes or less.
    # Even if sendto sent some of the data, there is no way for the other
    # end to reconstruct their order, so we don't include any logic to send
    # the rest of the data. Since it's UDP, the client already has to have
    # retry logic

    try:
        return sock.sendto(data, addr)
    except BlockingIOError:
        pass

    def writer():
        try:
            num_bytes = sock.sendto(data, addr)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_result(num_bytes)

    fileno = sock.fileno()
    result = Future()
    loop.add_writer(fileno, writer)

    try:
        return await result
    finally:
        loop.remove_writer(fileno)
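

if __name__ == '__main__':
    # A minimal usage sketch, not part of the library proper: start the proxy
    # with a single pass-through rule that rewrites every name to itself,
    # serve until interrupted, then shut down cleanly. Binding to port 53 in
    # get_socket_default typically requires elevated privileges
    from asyncio import Event, run

    async def main():
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        stop = await start()
        try:
            await Event().wait()
        finally:
            await stop()

    run(main())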