You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

224 lines
7.5 KiB

  1. from asyncio import (
  2. CancelledError,
  3. Queue,
  4. create_task,
  5. get_running_loop,
  6. )
  7. from enum import (
  8. IntEnum,
  9. )
  10. import logging
  11. import re
  12. from random import (
  13. choices,
  14. )
  15. import string
  16. import socket
  17. from aiodnsresolver import (
  18. RESPONSE,
  19. TYPES,
  20. DnsRecordDoesNotExist,
  21. DnsResponseCode,
  22. Message,
  23. Resolver,
  24. ResourceRecord,
  25. ResolverLoggerAdapter,
  26. pack,
  27. parse,
  28. recvfrom,
  29. )
  30. def get_socket_default():
  31. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  32. sock.setblocking(False)
  33. sock.bind(('', 53))
  34. return sock
  35. def get_resolver_default():
  36. return Resolver()
  37. def get_logger_default():
  38. return logging.getLogger('dnsrewriteproxy')
  39. class DnsProxyLoggerAdapter(logging.LoggerAdapter):
  40. def process(self, msg, kwargs):
  41. return \
  42. ('[dnsproxy] %s' % (msg,), kwargs) if not self.extra else \
  43. ('[dnsproxy:%s] %s' % (','.join(str(v) for v in self.extra.values()), msg), kwargs)
  44. def get_logger_adapter_default(extra):
  45. return DnsProxyLoggerAdapter(logging.getLogger('dnsrewriteproxy'), extra)
  46. def get_resolver_logger_adapter_default(parent_adapter):
  47. def _get_resolver_logger_adapter_default(dns_extra):
  48. return ResolverLoggerAdapter(parent_adapter, dns_extra)
  49. return _get_resolver_logger_adapter_default
  50. def DnsProxy(
  51. get_resolver=get_resolver_default,
  52. get_logger_adapter=get_logger_adapter_default,
  53. get_resolver_logger_adapter=get_resolver_logger_adapter_default,
  54. get_socket=get_socket_default, num_workers=1000,
  55. rules=(),
  56. ):
  57. class ERRORS(IntEnum):
  58. FORMERR = 1
  59. SERVFAIL = 2
  60. NXDOMAIN = 3
  61. REFUSED = 5
  62. loop = get_running_loop()
  63. logger = get_logger_adapter({})
  64. request_id_alphabet = string.ascii_letters + string.digits
  65. # The "main" task of the server: it receives incoming requests and puts
  66. # them in a queue that is then fetched from and processed by the proxy
  67. # workers
  68. async def server_worker(sock, resolve, stop):
  69. upstream_queue = Queue(maxsize=num_workers)
  70. # We have multiple upstream workers to be able to send multiple
  71. # requests upstream concurrently
  72. upstream_worker_tasks = [
  73. create_task(upstream_worker(sock, resolve, upstream_queue))
  74. for _ in range(0, num_workers)]
  75. try:
  76. while True:
  77. logger.info('Waiting for next request')
  78. request_data, addr = await recvfrom(loop, [sock], 512)
  79. request_logger = get_logger_adapter(
  80. {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))})
  81. request_logger.info('Received request from %s', addr)
  82. await upstream_queue.put((request_logger, request_data, addr))
  83. finally:
  84. logger.info('Stopping: waiting for requests to finish')
  85. await upstream_queue.join()
  86. logger.info('Stopping: cancelling workers...')
  87. for upstream_task in upstream_worker_tasks:
  88. upstream_task.cancel()
  89. for upstream_task in upstream_worker_tasks:
  90. try:
  91. await upstream_task
  92. except CancelledError:
  93. pass
  94. logger.info('Stopping: cancelling workers... (done)')
  95. logger.info('Stopping: final cleanup')
  96. await stop()
  97. logger.info('Stopping: done')
  98. async def upstream_worker(sock, resolve, upstream_queue):
  99. while True:
  100. request_logger, request_data, addr = await upstream_queue.get()
  101. try:
  102. request_logger.info('Processing request')
  103. response_data = await get_response_data(request_logger, resolve, request_data)
  104. # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
  105. # https://stackoverflow.com/a/59794872/1319998
  106. sock.sendto(response_data, addr)
  107. except Exception:
  108. request_logger.exception('Error processing request')
  109. finally:
  110. request_logger.info('Finished processing request')
  111. upstream_queue.task_done()
  112. async def get_response_data(request_logger, resolve, request_data):
  113. # This may raise an exception, which is handled at a higher level.
  114. # We can't [and I suspect shouldn't try to] return an error to the
  115. # client, since we're not able to extract the QID, so the client won't
  116. # be able to match it with an outgoing request
  117. query = parse(request_data)
  118. try:
  119. return pack(
  120. error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
  121. (await proxy(request_logger, resolve, query))
  122. )
  123. except Exception:
  124. request_logger.exception('Failed to proxy %s', query)
  125. return pack(error(query, ERRORS.SERVFAIL))
  126. async def proxy(request_logger, resolve, query):
  127. name_bytes = query.qd[0].name
  128. request_logger.info('Name: %s', name_bytes)
  129. name_str_lower = query.qd[0].name.lower().decode('idna')
  130. request_logger.info('Decoded: %s', name_str_lower)
  131. for pattern, replace in rules:
  132. rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
  133. if num_matches:
  134. request_logger.info('Matches rule (%s, %s)', pattern, replace)
  135. break
  136. else:
  137. # No break was triggered, i.e. no match
  138. request_logger.info('Does not match a rule')
  139. return error(query, ERRORS.REFUSED)
  140. try:
  141. ip_addresses = await resolve(
  142. rewritten_name_str, TYPES.A,
  143. get_logger_adapter=get_resolver_logger_adapter(request_logger))
  144. except DnsRecordDoesNotExist:
  145. request_logger.info('Does not exist')
  146. return error(query, ERRORS.NXDOMAIN)
  147. except DnsResponseCode as dns_response_code_error:
  148. request_logger.info('Received error from upstream: %s',
  149. dns_response_code_error.args[0])
  150. return error(query, dns_response_code_error.args[0])
  151. request_logger.info('Resolved to %s', ip_addresses)
  152. now = loop.time()
  153. def ttl(ip_address):
  154. return int(max(0.0, ip_address.expires_at - now))
  155. reponse_records = tuple(
  156. ResourceRecord(name=name_bytes, qtype=TYPES.A,
  157. qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
  158. for ip_address in ip_addresses
  159. )
  160. return Message(
  161. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  162. qd=query.qd, an=reponse_records, ns=(), ar=(),
  163. )
  164. async def start():
  165. # The socket is created synchronously and passed to the server worker,
  166. # so if there is an error creating it, this function will raise an
  167. # exception. If no exeption is raise, we are indeed listening#
  168. sock = get_socket()
  169. # The resolver is also created synchronously, since it can parse
  170. # /etc/hosts or /etc/resolve.conf, and can raise an exception if
  171. # something goes wrong with that
  172. resolve, clear_cache = get_resolver()
  173. async def stop():
  174. sock.close()
  175. await clear_cache()
  176. return create_task(server_worker(sock, resolve, stop))
  177. return start
  178. def error(query, rcode):
  179. return Message(
  180. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  181. qd=query.qd, an=(), ns=(), ar=(),
  182. )