You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

398 lines
13 KiB

  1. import asyncio
  2. import ipaddress
  3. import socket
  4. import struct
  5. import unittest
  6. from aiodnsresolver import (
  7. RESPONSE,
  8. TYPES,
  9. DnsRecordDoesNotExist,
  10. DnsResponseCode,
  11. DnsTimeout,
  12. IPv4AddressExpiresAt,
  13. Message,
  14. ResourceRecord,
  15. QuestionRecord,
  16. Resolver,
  17. pack,
  18. parse,
  19. recvfrom,
  20. )
  21. from dnsrewriteproxy import (
  22. DnsProxy,
  23. )
  24. def async_test(func):
  25. def wrapper(*args, **kwargs):
  26. future = func(*args, **kwargs)
  27. loop = asyncio.get_event_loop()
  28. loop.run_until_complete(future)
  29. return wrapper
  30. class TestProxy(unittest.TestCase):
  31. def add_async_cleanup(self, coroutine):
  32. self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
  33. @async_test
  34. async def test_e2e_no_match_rule(self):
  35. resolve, clear_cache = get_resolver(3535)
  36. self.add_async_cleanup(clear_cache)
  37. start = DnsProxy(get_socket=get_socket(3535))
  38. stop = await start()
  39. self.add_async_cleanup(stop)
  40. with self.assertRaises(DnsResponseCode) as cm:
  41. await resolve('www.google.com', TYPES.A)
  42. self.assertEqual(cm.exception.args[0], 5)
  43. @async_test
  44. async def test_e2e_match_all(self):
  45. resolve, clear_cache = get_resolver(3535)
  46. self.add_async_cleanup(clear_cache)
  47. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  48. stop = await start()
  49. self.add_async_cleanup(stop)
  50. response = await resolve('www.google.com', TYPES.A)
  51. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  52. @async_test
  53. async def test_e2e_default_port_match_all(self):
  54. resolve, clear_cache = get_resolver(53)
  55. self.add_async_cleanup(clear_cache)
  56. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  57. stop = await start()
  58. self.add_async_cleanup(stop)
  59. response = await resolve('www.google.com', TYPES.A)
  60. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  61. @async_test
  62. async def test_e2e_default_resolver_match_all_non_existing_domain(self):
  63. resolve, clear_cache = get_resolver(53)
  64. self.add_async_cleanup(clear_cache)
  65. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  66. stop = await start()
  67. self.add_async_cleanup(stop)
  68. with self.assertRaises(DnsRecordDoesNotExist):
  69. await resolve('doesnotexist.charemza.name', TYPES.A)
  70. @async_test
  71. async def test_e2e_default_resolver_match_all_bad_upstream(self):
  72. resolve, clear_cache = get_resolver(53, timeout=100)
  73. self.add_async_cleanup(clear_cache)
  74. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  75. stop = await start()
  76. self.add_async_cleanup(stop)
  77. with self.assertRaises(DnsResponseCode) as cm:
  78. await resolve('www.google.com', TYPES.A)
  79. self.assertEqual(cm.exception.args[0], 2)
  80. @async_test
  81. async def test_e2e_default_resolver_match_none_non_existing_domain(self):
  82. resolve, clear_cache = get_resolver(53)
  83. self.add_async_cleanup(clear_cache)
  84. start = DnsProxy()
  85. stop = await start()
  86. self.add_async_cleanup(stop)
  87. with self.assertRaises(DnsResponseCode) as cm:
  88. await resolve('doesnotexist.charemza.name', TYPES.A)
  89. self.assertEqual(cm.exception.args[0], 5)
  90. @async_test
  91. async def test_many_responses_with_small_socket_buffer(self):
  92. resolve, clear_cache = get_resolver(53)
  93. self.add_async_cleanup(clear_cache)
  94. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
  95. get_resolver=get_fixed_resolver)
  96. stop = await start()
  97. self.add_async_cleanup(stop)
  98. tasks = [
  99. asyncio.create_task(resolve('www.google.com', TYPES.A))
  100. for _ in range(0, 100000)
  101. ]
  102. responses = await asyncio.gather(*tasks)
  103. for response in responses:
  104. self.assertEqual(str(response[0]), '1.2.3.4')
  105. bing_responses = await resolve('www.bing.com', TYPES.A)
  106. self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)
  107. @async_test
  108. async def test_proxy_returns_error_from_upstream(self):
  109. rcode = 4
  110. async def get_response(query_data):
  111. query = parse(query_data)
  112. response = Message(
  113. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  114. qd=query.qd, an=(), ns=(), ar=(),
  115. )
  116. return pack(response)
  117. stop_nameserver = await start_nameserver(54, get_response)
  118. self.add_async_cleanup(stop_nameserver)
  119. resolve, clear_cache = get_resolver(53)
  120. self.add_async_cleanup(clear_cache)
  121. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  122. stop = await start()
  123. self.add_async_cleanup(stop)
  124. with self.assertRaises(DnsResponseCode) as cm:
  125. await resolve('www.google.com', TYPES.A)
  126. self.assertEqual(cm.exception.args[0], 4)
  127. rcode = 5
  128. with self.assertRaises(DnsResponseCode) as cm:
  129. await resolve('www.google.com', TYPES.A)
  130. self.assertEqual(cm.exception.args[0], 5)
  131. @async_test
  132. async def test_sending_bad_messages_not_affect_later_queries_a(self):
  133. resolve, clear_cache = get_resolver(53)
  134. self.add_async_cleanup(clear_cache)
  135. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  136. stop = await start()
  137. self.add_async_cleanup(stop)
  138. response = await resolve('www.google.com', TYPES.A)
  139. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  140. for _ in range(0, 100000):
  141. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  142. sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
  143. sock.close()
  144. tasks = [
  145. asyncio.create_task(resolve('www.google.com', TYPES.A))
  146. for _ in range(0, 100000)
  147. ]
  148. responses = await asyncio.gather(*tasks)
  149. for response in responses:
  150. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  151. @async_test
  152. async def test_sending_bad_messages_not_affect_later_queries_b(self):
  153. resolve, clear_cache = get_resolver(53)
  154. self.add_async_cleanup(clear_cache)
  155. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  156. stop = await start()
  157. self.add_async_cleanup(stop)
  158. response = await resolve('www.google.com', TYPES.A)
  159. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  160. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  161. for _ in range(0, 100000):
  162. sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
  163. sock.close()
  164. tasks = [
  165. asyncio.create_task(resolve('www.google.com', TYPES.A))
  166. for _ in range(0, 100000)
  167. ]
  168. responses = await asyncio.gather(*tasks)
  169. for response in responses:
  170. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  171. @async_test
  172. async def test_sending_pointer_loop_not_affect_later_queries_c(self):
  173. resolve, clear_cache = get_resolver(53)
  174. self.add_async_cleanup(clear_cache)
  175. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  176. stop = await start()
  177. self.add_async_cleanup(stop)
  178. response = await resolve('www.google.com', TYPES.A)
  179. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  180. name = b'mydomain.com'
  181. question_record = QuestionRecord(name, TYPES.A, qclass=1)
  182. record_1 = ResourceRecord(
  183. name=name, qtype=TYPES.A, qclass=1, ttl=0,
  184. rdata=ipaddress.IPv4Address('123.100.124.1').packed,
  185. )
  186. response = Message(
  187. qid=1, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  188. qd=(question_record,), an=(record_1,), ns=(), ar=(),
  189. )
  190. data = pack(response)
  191. packed_name = b''.join(
  192. component
  193. for label in name.split(b'.')
  194. for component in (bytes([len(label)]), label)
  195. ) + b'\0'
  196. occurance_1 = data.index(packed_name)
  197. occurance_1_end = occurance_1 + len(packed_name)
  198. occurance_2 = occurance_1_end + data[occurance_1_end:].index(packed_name)
  199. occurance_2_end = occurance_2 + len(packed_name)
  200. data_compressed = \
  201. data[:occurance_2] + \
  202. struct.pack('!H', (192 * 256) + occurance_2 + 4) + \
  203. struct.pack('!H', (192 * 256) + occurance_2) + \
  204. struct.pack('!H', (192 * 256) + occurance_2 + 2) + \
  205. data[occurance_2_end:]
  206. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  207. sock.sendto(data_compressed, ('127.0.0.1', 53))
  208. sock.close()
  209. tasks = [
  210. asyncio.create_task(resolve('www.google.com', TYPES.A))
  211. for _ in range(0, 100000)
  212. ]
  213. responses = await asyncio.gather(*tasks)
  214. for response in responses:
  215. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  216. @async_test
  217. async def test_too_large_response_from_upstream_not_affect_later(self):
  218. num_records = 200
  219. async def get_response(query_data):
  220. query = parse(query_data)
  221. response_records = tuple(
  222. ResourceRecord(
  223. name=query.qd[0].name,
  224. qtype=TYPES.A,
  225. qclass=1,
  226. ttl=0,
  227. rdata=ipaddress.IPv4Address('123.100.123.' + str(i)).packed,
  228. ) for i in range(0, num_records)
  229. )
  230. response = Message(
  231. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  232. qd=query.qd, an=response_records, ns=(), ar=(),
  233. )
  234. return pack(response)
  235. stop_nameserver = await start_nameserver(54, get_response)
  236. self.add_async_cleanup(stop_nameserver)
  237. resolve, clear_cache = get_resolver(53)
  238. self.add_async_cleanup(clear_cache)
  239. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  240. stop = await start()
  241. self.add_async_cleanup(stop)
  242. tasks = [
  243. asyncio.create_task(resolve('www.google.com', TYPES.A))
  244. for _ in range(0, 100000)
  245. ]
  246. for task in tasks:
  247. with self.assertRaises(DnsTimeout):
  248. await task
  249. num_records = 1
  250. tasks = [
  251. asyncio.create_task(resolve('www.google.com', TYPES.A))
  252. for _ in range(0, 100000)
  253. ]
  254. responses = await asyncio.gather(*tasks)
  255. for response in responses:
  256. self.assertEqual(str(response[0]), '123.100.123.0')
  257. def get_socket(port):
  258. def _get_socket():
  259. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  260. sock.setblocking(False)
  261. sock.bind(('', port))
  262. return sock
  263. return _get_socket
  264. def get_small_socket():
  265. # For linux, the minimum buffer size is 1024
  266. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  267. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
  268. sock.setblocking(False)
  269. sock.bind(('', 53))
  270. return sock
  271. def get_resolver(port, timeout=0.5):
  272. async def get_nameservers(_, __):
  273. for _ in range(0, 5):
  274. yield (timeout, ('127.0.0.1', port))
  275. return Resolver(get_nameservers=get_nameservers)
  276. def get_fixed_resolver():
  277. async def get_host(_, fqdn, qtype):
  278. hosts = {
  279. b'www.google.com': {
  280. TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
  281. },
  282. }
  283. try:
  284. return hosts[fqdn.lower()][qtype]
  285. except KeyError:
  286. return None
  287. return Resolver(get_host=get_host)
  288. async def start_nameserver(port, get_response):
  289. # For some tests we need to control the responses from upstream, especially in the cases
  290. # where it's not behaving
  291. loop = asyncio.get_event_loop()
  292. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  293. sock.setblocking(False)
  294. sock.bind(('', port))
  295. async def server():
  296. client_tasks = []
  297. try:
  298. while True:
  299. data, addr = await recvfrom(loop, [sock], 512)
  300. client_tasks.append(asyncio.ensure_future(client_task(data, addr)))
  301. finally:
  302. for task in client_tasks:
  303. task.cancel()
  304. async def client_task(data, addr):
  305. response = await get_response(data)
  306. sock.sendto(response, addr)
  307. server_task = asyncio.ensure_future(server())
  308. async def stop():
  309. server_task.cancel()
  310. await asyncio.sleep(0)
  311. sock.close()
  312. return stop