from asyncio import (
    CancelledError,
    Future,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import re
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    pack,
    parse,
)


def get_socket_default():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    return Resolver()


def get_resolver_with_upstream(upstream):
    async def get_nameservers(_, __):
        for _ in range(0, 5):
            yield (0.5, (upstream, 53))

    return Resolver(get_nameservers=get_nameservers)


def get_logger_default():
    return logging.getLogger('dnsrewriteproxy')
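

# A hypothetical wiring example (not part of the original module): to forward
# all rewritten queries to a specific upstream server, the factory above can
# be partially applied so DnsProxy can call it with no arguments. The
# 203.0.113.53 address is purely illustrative.
#
#     from functools import partial
#     start = DnsProxy(get_resolver=partial(get_resolver_with_upstream, '203.0.113.53'))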


def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
        rules=(),
):

    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

        def __str__(self):
            return self.name

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve):
        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
        upstream_queue = Queue(maxsize=upstream_queue_maxsize)

        # It would "usually" be ok to send downstream from multiple tasks, but
        # if the socket has a full buffer, it would raise a BlockingIOError,
        # and we would need to attach a reader. We can only attach one reader
        # per underlying file, and since we have a single socket, we have a
        # single file. So we send downstream from a single task
        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                request_data, addr = await recvfrom(loop, sock, 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests, which can add to the downstream
            # queue
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            # Ensure we have sent the responses downstream
            await downstream_queue.join()
            downstream_worker_task.cancel()
            # Wait for the tasks to really be finished
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except Exception:
                    pass
            try:
                await downstream_worker_task
            except Exception:
                pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from get_response_data %s', addr)
                upstream_queue.task_done()
                continue

            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    async def downstream_worker(sock, downstream_queue):
        while True:
            response_data, addr = await downstream_queue.get()
            await sendto(loop, sock, response_data, addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str)
            if num_matches:
                break
        else:
            # No break was triggered, i.e. no match
            return error(query, ERRORS.REFUSED)

        try:
            ip_addresses = await resolve(rewritten_name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        def ttl(ip_address):
            return int(max(0.0, ip_address.expires_at - loop.time()))

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )
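
    # A purely illustrative note on the rules mechanism above (the example
    # names are assumptions, not part of the original module): with
    # rules=((r'^www\.internal\.example$', 'www.public.example'),), an A query
    # for www.internal.example is re-resolved as www.public.example, while the
    # records returned to the client keep the originally queried name, since
    # `name_bytes` is taken from the query rather than from the rewrite.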

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass
            sock.close()
            await clear_cache()

        return stop

    return start


def error(query, rcode):
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )


async def recvfrom(loop, sock, max_bytes):
    try:
        return sock.recvfrom(max_bytes)
    except BlockingIOError:
        pass

    def reader():
        try:
            (data, addr) = sock.recvfrom(max_bytes)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_result((data, addr))

    fileno = sock.fileno()
    result = Future()
    loop.add_reader(fileno, reader)

    try:
        return await result
    finally:
        loop.remove_reader(fileno)


async def sendto(loop, sock, data, addr):
    # In our case, the UDP responses will always be 512 bytes or less.
    # Even if sendto sent only some of the data, there is no way for the other
    # end to reconstruct their order, so we don't include any logic to send
    # the rest of the data. Since it's UDP, the client already has to have
    # retry logic
    try:
        return sock.sendto(data, addr)
    except BlockingIOError:
        pass

    def writer():
        try:
            num_bytes = sock.sendto(data, addr)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_result(num_bytes)

    fileno = sock.fileno()
    result = Future()
    loop.add_writer(fileno, writer)

    try:
        return await result
    finally:
        loop.remove_writer(fileno)
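

# A minimal usage sketch, not part of the original module: it assumes port 53
# is bindable (typically requiring elevated privileges), that aiodnsresolver
# is installed, and that the single catch-all rewrite rule shown is
# illustrative only.
if __name__ == '__main__':
    import asyncio

    async def main():
        # DnsProxy returns `start`; awaiting it binds the socket, creates the
        # resolver, and returns `stop` for a clean shutdown
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        stop = await start()
        await asyncio.sleep(60)  # serve for a minute, then shut down cleanly
        await stop()

    asyncio.run(main())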