from asyncio import (
    CancelledError,
    Future,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    pack,
    parse,
    recvfrom,
)

def get_socket_default():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    return Resolver()


def get_logger_default():
    return logging.getLogger('dnsrewriteproxy')

def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
        rules=(),
):
    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger()
  48. # The "main" task of the server: it receives incoming requests and puts
  49. # them in a queue that is then fetched from and processed by the proxy
  50. # workers
  51. async def server_worker(sock, resolve):
  52. downstream_queue = Queue(maxsize=downstream_queue_maxsize)
  53. upstream_queue = Queue(maxsize=upstream_queue_maxsize)
  54. # It would "usually" be ok to send downstream from multiple tasks, but
  55. # if the socket has a full buffer, it would raise a BlockingIOError,
  56. # and we will need to attach a reader. We can only attach one reader
  57. # per underlying file, and since we have a single socket, we have a
  58. # single file. So we send downstream from a single task
  59. downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))
  60. # We have multiple upstream workers to be able to send multiple
  61. # requests upstream concurrently, and add responses to downstream_queue
  62. upstream_worker_tasks = [
  63. create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
  64. for _ in range(0, num_workers)]
  65. try:
  66. while True:
  67. request_data, addr = await recvfrom(loop, [sock], 512)
  68. await upstream_queue.put((request_data, addr))
  69. finally:
  70. # Finish upstream requests, which can add to to the downstream
  71. # queue
  72. await upstream_queue.join()
  73. for upstream_task in upstream_worker_tasks:
  74. upstream_task.cancel()
  75. # Ensure we have sent the responses downstream
  76. await downstream_queue.join()
  77. downstream_worker_task.cancel()
  78. # Wait for the tasks to really be finished
  79. for upstream_task in upstream_worker_tasks:
  80. try:
  81. await upstream_task
  82. except Exception:
  83. pass
  84. try:
  85. await downstream_worker_task
  86. except Exception:
  87. pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from get_response_data %s', addr)
                upstream_queue.task_done()
                continue

            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    # Sends responses back to clients from a single task; see the comment in
    # server_worker on why only one task writes to the socket
    async def downstream_worker(sock, downstream_queue):
        while True:
            response_data, addr = await downstream_queue.get()
            await sendto(loop, sock, response_data, addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        # Rules are (pattern, replacement) regex pairs, tried in order; the
        # first pattern that matches determines the rewritten name
        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no match
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        def ttl(ip_address):
            return int(max(0.0, ip_address.expires_at - loop.time()))

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass
            sock.close()
            await clear_cache()

        return stop

    return start

def error(query, rcode):
    # Build a response that echoes the query's QID and question section, with
    # the given error rcode and no answer records
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )

async def sendto(loop, sock, data, addr):
    # In our case, the UDP responses will always be 512 bytes or less.
    # Even if sendto sent some of the data, there is no way for the other
    # end to reconstruct their order, so we don't include any logic to send
    # the rest of the data. Since it's UDP, the client already has to have
    # retry logic
    try:
        return sock.sendto(data, addr)
    except BlockingIOError:
        pass

    def writer():
        try:
            num_bytes = sock.sendto(data, addr)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_result(num_bytes)

    fileno = sock.fileno()
    result = Future()
    loop.add_writer(fileno, writer)

    try:
        return await result
    finally:
        loop.remove_writer(fileno)
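

# A minimal usage sketch, not part of the original module: it assumes
# DnsProxy is importable (e.g. this file saved as dnsrewriteproxy.py) and
# that the process is allowed to bind port 53, which usually requires root
# or CAP_NET_BIND_SERVICE. The rewrite rule and the 60-second lifetime
# below are hypothetical examples.
if __name__ == '__main__':
    from asyncio import run, sleep

    async def main():
        # DnsProxy returns a `start` coroutine function; awaiting it binds
        # the socket, spawns the workers, and returns a `stop` coroutine
        # function for graceful shutdown. DnsProxy itself must be called
        # from within a running event loop, since it calls get_running_loop
        start = DnsProxy(rules=(
            # Queries for www.source.example are resolved upstream as
            # www.target.example; names matching no rule are REFUSED
            (r'^www\.source\.example$', 'www.target.example'),
        ))
        stop = await start()
        try:
            await sleep(60)
        finally:
            await stop()

    run(main())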