from asyncio import (
    CancelledError,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    pack,
    parse,
    recvfrom,
)


def get_socket_default():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    return Resolver()


def get_logger_default():
    return logging.getLogger('dnsrewriteproxy')


def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
        rules=(),
):
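
    # `rules` is a sequence of (pattern, replace) pairs: each incoming
    # question name is run through re.subn for each pair in order, and the
    # first pattern that matches determines the rewritten name that is
    # resolved upstream; if nothing matches, the query is refused. A
    # hypothetical example rule, rewriting one name to another:
    #
    #     rules=(
    #         (r'^www\.source\.example$', 'www.target.example'),
    #     )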

    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve):
        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
        upstream_queue = Queue(maxsize=upstream_queue_maxsize)

        # It would "usually" be ok to send downstream from multiple tasks, but
        # if the socket has a full buffer, it would raise a BlockingIOError,
        # and we would need to attach a writer. We can only attach one writer
        # per underlying file, and since we have a single socket, we have a
        # single file. So we send downstream from a single task
        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                request_data, addr = await recvfrom(loop, [sock], 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests, which can add to the downstream
            # queue
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()

            # Ensure we have sent the responses downstream
            await downstream_queue.join()
            downstream_worker_task.cancel()

            # Wait for the tasks to really be finished
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except Exception:
                    pass
            try:
                await downstream_worker_task
            except Exception:
                pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()
            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from get_response_data %s', addr)
                upstream_queue.task_done()
                continue
            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    async def downstream_worker(sock, downstream_queue):
        while True:
            response_data, addr = await downstream_queue.get()
            try:
                await sendto(sock, response_data, addr)
            except Exception:
                logger.exception('Unable to send response to %s', addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no match
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        def ttl(ip_address):
            return int(max(0.0, ip_address.expires_at - loop.time()))

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass
            sock.close()
            await clear_cache()

        return stop

    return start


def error(query, rcode):
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )


async def sendto(sock, data, addr):
    # In our case, the UDP responses will always be 512 bytes or less.
    # Even if sendto sent only some of the data, there is no way for the
    # other end to reconstruct their order, so we don't include any logic to
    # send the rest of the data. Since it's UDP, the client already has to
    # have retry logic.
    #
    # This can also potentially raise a BlockingIOError, but even when trying
    # to force high numbers of messages through a small socket buffer, this
    # has never been observed. As above, the client must have retry logic, so
    # we leave it to the client to deal with this.
    return sock.sendto(data, addr)
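

# A minimal, hypothetical usage sketch: DnsProxy returns a `start` coroutine
# function, and awaiting `start()` begins serving and returns a `stop`
# coroutine function. Binding port 53 typically requires elevated privileges,
# and the pass-through rule below is illustrative only.
if __name__ == '__main__':
    from asyncio import Event, run

    async def main():
        # Forward all A queries upstream with the name unchanged
        start = DnsProxy(rules=((r'(^.*)$', r'\1'),))
        stop = await start()
        try:
            await Event().wait()  # Serve until the process is interrupted
        finally:
            await stop()

    run(main())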