You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

278 lines
8.5 KiB

  1. from asyncio import (
  2. CancelledError,
  3. Future,
  4. Queue,
  5. create_task,
  6. get_running_loop,
  7. )
  8. from enum import (
  9. IntEnum,
  10. )
  11. import logging
  12. import re
  13. import socket
  14. from aiodnsresolver import (
  15. RESPONSE,
  16. TYPES,
  17. DnsRecordDoesNotExist,
  18. DnsResponseCode,
  19. Message,
  20. Resolver,
  21. ResourceRecord,
  22. pack,
  23. parse,
  24. )
  25. def get_socket_default():
  26. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  27. sock.setblocking(False)
  28. sock.bind(('', 53))
  29. return sock
  30. def get_resolver_default():
  31. return Resolver()
  32. def get_logger_default():
  33. return logging.getLogger('dnsrewriteproxy')
def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
        rules=(),
):
    """Return a ``start`` coroutine function for a rewriting DNS proxy.

    The proxy listens for UDP DNS queries on the socket from
    ``get_socket``, rewrites the queried name using the
    (pattern, replacement) regex pairs in ``rules`` (first matching rule
    wins), resolves the rewritten name upstream via ``get_resolver``, and
    responds to the client under the originally queried name.

    Queries matching no rule, and queries that are not for A records, are
    answered with REFUSED. Must be called while an event loop is running
    (it calls ``get_running_loop``). Awaiting ``start()`` begins serving
    and returns a ``stop`` coroutine function that shuts everything down.
    """

    class ERRORS(IntEnum):
        # DNS response codes (RCODEs) used for error responses
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

        def __str__(self):
            return self.name

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve):
        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
        upstream_queue = Queue(maxsize=upstream_queue_maxsize)

        # It would "usually" be ok to send downstream from multiple tasks, but
        # if the socket has a full buffer, it would raise a BlockingIOError,
        # and we will need to attach a reader. We can only attach one reader
        # per underlying file, and since we have a single socket, we have a
        # single file. So we send downstream from a single task
        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                # 512 bytes is the classic maximum UDP DNS message size
                request_data, addr = await recvfrom(loop, sock, 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests, which can add to the downstream
            # queue
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            # Ensure we have sent the responses downstream
            await downstream_queue.join()
            downstream_worker_task.cancel()
            # Wait for the tasks to really be finished
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except Exception:
                    pass
            try:
                await downstream_worker_task
            except Exception:
                pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        # Pulls raw requests off upstream_queue, resolves them, and hands
        # the serialised responses to the single downstream sender
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from handler_request_data %s', addr)
                upstream_queue.task_done()
                continue

            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    async def downstream_worker(sock, downstream_queue):
        # The only task that sends on the shared socket (see comment in
        # server_worker on why sends come from a single task)
        while True:
            response_data, addr = await downstream_queue.get()
            await sendto(loop, sock, response_data, addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        # A query with an empty question section cannot be proxied
        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            # Only A-record queries are proxied; everything else is REFUSED
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        # Rewrite the queried name via `rules`, resolve the rewritten name
        # upstream, and answer under the original name
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no match
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            # Forward the upstream resolver's error code to the client
            return error(query, dns_response_code_error.args[0])

        def ttl(ip_address):
            # Remaining lifetime of the cached upstream answer, clamped at 0
            return int(max(0.0, ip_address.expires_at - loop.time()))

        reponse_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=reponse_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolve.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            # Cancel the server worker first (its `finally` drains the
            # queues), then release the socket and the resolver cache
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass
            sock.close()
            await clear_cache()

        return stop

    return start
  168. def error(query, rcode):
  169. return Message(
  170. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  171. qd=query.qd, an=(), ns=(), ar=(),
  172. )
  173. async def recvfrom(loop, sock, max_bytes):
  174. try:
  175. return sock.recvfrom(max_bytes)
  176. except BlockingIOError:
  177. pass
  178. def reader():
  179. try:
  180. (data, addr) = sock.recvfrom(max_bytes)
  181. except BlockingIOError:
  182. pass
  183. except BaseException as exception:
  184. loop.remove_reader(fileno)
  185. if not result.done():
  186. result.set_exception(exception)
  187. else:
  188. loop.remove_reader(fileno)
  189. if not result.done():
  190. result.set_result((data, addr))
  191. fileno = sock.fileno()
  192. result = Future()
  193. loop.add_reader(fileno, reader)
  194. try:
  195. return await result
  196. finally:
  197. loop.remove_reader(fileno)
  198. async def sendto(loop, sock, data, addr):
  199. # In our cases, the UDP responses will always be 512 bytes or less.
  200. # Even if sendto sent some of the data, there is no way for the other
  201. # end to reconstruct their order, so we don't include any logic to send
  202. # the rest of the data. Since it's UDP, the client already has to have
  203. # retry logic
  204. try:
  205. return sock.sendto(data, addr)
  206. except BlockingIOError:
  207. pass
  208. def writer():
  209. try:
  210. num_bytes = sock.sendto(data, addr)
  211. except BlockingIOError:
  212. pass
  213. except BaseException as exception:
  214. loop.remove_writer(fileno)
  215. if not result.done():
  216. result.set_exception(exception)
  217. else:
  218. loop.remove_writer(fileno)
  219. if not result.done():
  220. result.set_result(num_bytes)
  221. fileno = sock.fileno()
  222. result = Future()
  223. loop.add_writer(fileno, writer)
  224. try:
  225. return await result
  226. finally:
  227. loop.remove_writer(fileno)