You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

183 lines
5.6 KiB

  1. from asyncio import (
  2. CancelledError,
  3. Queue,
  4. create_task,
  5. get_running_loop,
  6. )
  7. from enum import (
  8. IntEnum,
  9. )
  10. import logging
  11. import re
  12. import socket
  13. from aiodnsresolver import (
  14. RESPONSE,
  15. TYPES,
  16. DnsRecordDoesNotExist,
  17. DnsResponseCode,
  18. Message,
  19. Resolver,
  20. ResourceRecord,
  21. pack,
  22. parse,
  23. recvfrom,
  24. )
  25. def get_socket_default():
  26. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  27. sock.setblocking(False)
  28. sock.bind(('', 53))
  29. return sock
  30. def get_resolver_default():
  31. return Resolver()
  32. def get_logger_default():
  33. return logging.getLogger('dnsrewriteproxy')
def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default, num_workers=1000,
        rules=(),
):
    """Build a rewriting DNS proxy server.

    Only A queries are proxied; any other query type is answered REFUSED.
    The queried name is matched against ``rules`` — an iterable of
    ``(pattern, replacement)`` pairs applied with ``re.subn`` — and the first
    pattern that matches rewrites the name, which is then resolved upstream.
    Names matching no rule are answered REFUSED.

    :param get_resolver: zero-argument callable returning a
        ``(resolve, clear_cache)`` pair (aiodnsresolver-style)
    :param get_logger: zero-argument callable returning the logger to use
    :param get_socket: zero-argument callable returning the bound,
        non-blocking UDP socket to serve on
    :param num_workers: number of concurrent upstream workers, and also the
        maximum size of the queue of pending requests
    :param rules: iterable of ``(pattern, replacement)`` regex pairs
    :returns: a ``start`` coroutine function; awaiting it starts the server
        and returns a ``stop`` coroutine function that shuts it down
    """

    # Subset of DNS RCODEs this proxy sends back to clients
    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve):
        upstream_queue = Queue(maxsize=num_workers)

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(sock, resolve, upstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                # 512 bytes is the classic maximum DNS-over-UDP message size
                request_data, addr = await recvfrom(loop, [sock], 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests: drain the queue before cancelling the
            # workers so accepted requests are still answered during shutdown
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except CancelledError:
                    pass

    # Pulls one request at a time off the queue, proxies it, and sends the
    # response back to the client. Any failure is logged, never propagated,
    # so a worker survives malformed or unproxyable requests.
    async def upstream_worker(sock, resolve, upstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
                # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
                # https://stackoverflow.com/a/59794872/1319998
                sock.sendto(response_data, addr)
            except Exception:
                logger.exception('Processing request from %s', addr)
            finally:
                # Always mark the item done so upstream_queue.join() can finish
                upstream_queue.task_done()

    # Parses the raw request and returns the packed response bytes: REFUSED
    # for non-A queries, otherwise the proxied answer; SERVFAIL on any
    # unexpected failure while proxying.
    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    # Rewrites the queried name via the first matching rule, resolves it
    # upstream, and builds the response Message (or an error Message).
    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        # DNS names are case-insensitive; the wire name is lowercased and
        # IDNA-decoded before rule matching and upstream resolution
        name_str_lower = query.qd[0].name.lower().decode('idna')

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no match
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        now = loop.time()

        # Remaining TTL for a resolved address, clamped at zero
        def ttl(ip_address):
            return int(max(0.0, ip_address.expires_at - now))

        # Answer records carry the ORIGINAL queried name (name_bytes), not the
        # rewritten one, so the client sees answers for what it asked
        reponse_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=reponse_records, ns=(), ar=(),
        )

    async def start():
        """Start serving; returns a ``stop`` coroutine function."""
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            """Cancel the server, close the socket, clear the resolver cache."""
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass
            sock.close()
            await clear_cache()

        return stop

    return start
  144. def error(query, rcode):
  145. return Message(
  146. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  147. qd=query.qd, an=(), ns=(), ar=(),
  148. )