You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

239 lines
7.1 KiB

  1. import asyncio
  2. import socket
  3. import unittest
  4. from aiodnsresolver import (
  5. RESPONSE,
  6. TYPES,
  7. DnsRecordDoesNotExist,
  8. DnsResponseCode,
  9. IPv4AddressExpiresAt,
  10. Message,
  11. Resolver,
  12. pack,
  13. parse,
  14. recvfrom,
  15. )
  16. from dnsrewriteproxy import (
  17. DnsProxy,
  18. )
  19. def async_test(func):
  20. def wrapper(*args, **kwargs):
  21. future = func(*args, **kwargs)
  22. loop = asyncio.get_event_loop()
  23. loop.run_until_complete(future)
  24. return wrapper
  25. class TestProxy(unittest.TestCase):
  26. def add_async_cleanup(self, coroutine):
  27. self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
  28. @async_test
  29. async def test_e2e_no_match_rule(self):
  30. resolve, clear_cache = get_resolver(3535)
  31. self.add_async_cleanup(clear_cache)
  32. start = DnsProxy(get_socket=get_socket(3535))
  33. stop = await start()
  34. self.add_async_cleanup(stop)
  35. with self.assertRaises(DnsResponseCode) as cm:
  36. await resolve('www.google.com', TYPES.A)
  37. self.assertEqual(cm.exception.args[0], 5)
  38. @async_test
  39. async def test_e2e_match_all(self):
  40. resolve, clear_cache = get_resolver(3535)
  41. self.add_async_cleanup(clear_cache)
  42. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  43. stop = await start()
  44. self.add_async_cleanup(stop)
  45. response = await resolve('www.google.com', TYPES.A)
  46. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  47. @async_test
  48. async def test_e2e_default_port_match_all(self):
  49. resolve, clear_cache = get_resolver(53)
  50. self.add_async_cleanup(clear_cache)
  51. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  52. stop = await start()
  53. self.add_async_cleanup(stop)
  54. response = await resolve('www.google.com', TYPES.A)
  55. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  56. @async_test
  57. async def test_e2e_default_resolver_match_all_non_existing_domain(self):
  58. resolve, clear_cache = get_resolver(53)
  59. self.add_async_cleanup(clear_cache)
  60. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  61. stop = await start()
  62. self.add_async_cleanup(stop)
  63. with self.assertRaises(DnsRecordDoesNotExist):
  64. await resolve('doesnotexist.charemza.name', TYPES.A)
  65. @async_test
  66. async def test_e2e_default_resolver_match_all_bad_upstream(self):
  67. resolve, clear_cache = get_resolver(53, timeout=100)
  68. self.add_async_cleanup(clear_cache)
  69. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  70. stop = await start()
  71. self.add_async_cleanup(stop)
  72. with self.assertRaises(DnsResponseCode) as cm:
  73. await resolve('www.google.com', TYPES.A)
  74. self.assertEqual(cm.exception.args[0], 2)
  75. @async_test
  76. async def test_e2e_default_resolver_match_none_non_existing_domain(self):
  77. resolve, clear_cache = get_resolver(53)
  78. self.add_async_cleanup(clear_cache)
  79. start = DnsProxy()
  80. stop = await start()
  81. self.add_async_cleanup(stop)
  82. with self.assertRaises(DnsResponseCode) as cm:
  83. await resolve('doesnotexist.charemza.name', TYPES.A)
  84. self.assertEqual(cm.exception.args[0], 5)
  85. @async_test
  86. async def test_many_responses_with_small_socket_buffer(self):
  87. resolve, clear_cache = get_resolver(53)
  88. self.add_async_cleanup(clear_cache)
  89. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
  90. get_resolver=get_fixed_resolver)
  91. stop = await start()
  92. self.add_async_cleanup(stop)
  93. tasks = [
  94. asyncio.create_task(resolve('www.google.com', TYPES.A))
  95. for _ in range(0, 100000)
  96. ]
  97. responses = await asyncio.gather(*tasks)
  98. for response in responses:
  99. self.assertEqual(str(response[0]), '1.2.3.4')
  100. bing_responses = await resolve('www.bing.com', TYPES.A)
  101. self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)
  102. @async_test
  103. async def test_proxy_returns_error_from_upstream(self):
  104. rcode = 4
  105. async def get_response(query_data):
  106. query = parse(query_data)
  107. response = Message(
  108. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  109. qd=query.qd, an=(), ns=(), ar=(),
  110. )
  111. return pack(response)
  112. stop_nameserver = await start_nameserver(54, get_response)
  113. self.add_async_cleanup(stop_nameserver)
  114. resolve, clear_cache = get_resolver(53)
  115. self.add_async_cleanup(clear_cache)
  116. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  117. stop = await start()
  118. self.add_async_cleanup(stop)
  119. with self.assertRaises(DnsResponseCode) as cm:
  120. await resolve('www.google.com', TYPES.A)
  121. self.assertEqual(cm.exception.args[0], 4)
  122. rcode = 5
  123. with self.assertRaises(DnsResponseCode) as cm:
  124. await resolve('www.google.com', TYPES.A)
  125. self.assertEqual(cm.exception.args[0], 5)
  126. def get_socket(port):
  127. def _get_socket():
  128. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  129. sock.setblocking(False)
  130. sock.bind(('', port))
  131. return sock
  132. return _get_socket
  133. def get_small_socket():
  134. # For linux, the minimum buffer size is 1024
  135. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  136. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
  137. sock.setblocking(False)
  138. sock.bind(('', 53))
  139. return sock
  140. def get_resolver(port, timeout=0.5):
  141. async def get_nameservers(_, __):
  142. for _ in range(0, 5):
  143. yield (timeout, ('127.0.0.1', port))
  144. return Resolver(get_nameservers=get_nameservers)
  145. def get_fixed_resolver():
  146. async def get_host(_, fqdn, qtype):
  147. hosts = {
  148. b'www.google.com': {
  149. TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
  150. },
  151. }
  152. try:
  153. return hosts[fqdn.lower()][qtype]
  154. except KeyError:
  155. return None
  156. return Resolver(get_host=get_host)
  157. async def start_nameserver(port, get_response):
  158. # For some tests we need to control the responses from upstream, especially in the cases
  159. # where it's not behaving
  160. loop = asyncio.get_event_loop()
  161. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  162. sock.setblocking(False)
  163. sock.bind(('', port))
  164. async def server():
  165. client_tasks = []
  166. try:
  167. while True:
  168. data, addr = await recvfrom(loop, [sock], 512)
  169. client_tasks.append(asyncio.ensure_future(client_task(data, addr)))
  170. finally:
  171. for task in client_tasks:
  172. task.cancel()
  173. async def client_task(data, addr):
  174. response = await get_response(data)
  175. sock.sendto(response, addr)
  176. server_task = asyncio.ensure_future(server())
  177. async def stop():
  178. server_task.cancel()
  179. await asyncio.sleep(0)
  180. sock.close()
  181. return stop