You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

566 lines
19 KiB

import asyncio
import functools
import ipaddress
import socket
import struct
import unittest

from aiodnsresolver import (
    RESPONSE,
    QUESTION,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    DnsTimeout,
    IPv4AddressExpiresAt,
    Message,
    ResourceRecord,
    QuestionRecord,
    Resolver,
    pack,
    parse,
    recvfrom,
)

from dnsrewriteproxy import (
    DnsProxy,
)
  25. def async_test(func):
  26. def wrapper(*args, **kwargs):
  27. future = func(*args, **kwargs)
  28. loop = asyncio.get_event_loop()
  29. loop.run_until_complete(future)
  30. return wrapper
  31. class TestProxy(unittest.TestCase):
  32. def add_async_cleanup(self, coroutine, *args):
  33. self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine(*args))
  34. @async_test
  35. async def test_e2e_no_match_rule(self):
  36. resolve, clear_cache = get_resolver(3535)
  37. self.add_async_cleanup(clear_cache)
  38. start = DnsProxy(get_socket=get_socket(3535))
  39. server_task = await start()
  40. self.add_async_cleanup(await_cancel, server_task)
  41. with self.assertRaises(DnsResponseCode) as cm:
  42. await resolve('www.google.com', TYPES.A)
  43. self.assertEqual(cm.exception.args[0], 5)
  44. @async_test
  45. async def test_e2e_match_all(self):
  46. resolve, clear_cache = get_resolver(3535)
  47. self.add_async_cleanup(clear_cache)
  48. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  49. server_task = await start()
  50. self.add_async_cleanup(await_cancel, server_task)
  51. response = await resolve('www.google.com', TYPES.A)
  52. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  53. @async_test
  54. async def test_e2e_match_all_wrong_type(self):
  55. resolve, clear_cache = get_resolver(3535)
  56. self.add_async_cleanup(clear_cache)
  57. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  58. server_task = await start()
  59. self.add_async_cleanup(await_cancel, server_task)
  60. with self.assertRaises(DnsResponseCode) as cm:
  61. await resolve('www.google.com', TYPES.AAAA)
  62. self.assertEqual(cm.exception.args[0], 5)
  63. @async_test
  64. async def test_e2e_default_port_match_all(self):
  65. resolve, clear_cache = get_resolver(53)
  66. self.add_async_cleanup(clear_cache)
  67. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  68. server_task = await start()
  69. self.add_async_cleanup(await_cancel, server_task)
  70. response = await resolve('www.google.com', TYPES.A)
  71. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  72. @async_test
  73. async def test_e2e_default_resolver_match_all_non_existing_domain(self):
  74. resolve, clear_cache = get_resolver(53)
  75. self.add_async_cleanup(clear_cache)
  76. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  77. server_task = await start()
  78. self.add_async_cleanup(await_cancel, server_task)
  79. with self.assertRaises(DnsRecordDoesNotExist):
  80. await resolve('doesnotexist.charemza.name', TYPES.A)
  81. @async_test
  82. async def test_e2e_default_resolver_rewrite_non_existing_to_existing(self):
  83. resolve, clear_cache = get_resolver(53)
  84. self.add_async_cleanup(clear_cache)
  85. start = DnsProxy(rules=((r'^doesnotexist\.charemza\.name$', r'www.google.com'),))
  86. server_task = await start()
  87. self.add_async_cleanup(await_cancel, server_task)
  88. response = await resolve('doesnotexist.charemza.name', TYPES.A)
  89. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  90. @async_test
  91. async def test_e2e_default_resolver_match_all_bad_upstream(self):
  92. resolve, clear_cache = get_resolver(53, timeout=100)
  93. self.add_async_cleanup(clear_cache)
  94. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  95. server_task = await start()
  96. self.add_async_cleanup(await_cancel, server_task)
  97. with self.assertRaises(DnsResponseCode) as cm:
  98. await resolve('www.google.com', TYPES.A)
  99. self.assertEqual(cm.exception.args[0], 2)
  100. @async_test
  101. async def test_e2e_default_resolver_match_none_non_existing_domain(self):
  102. resolve, clear_cache = get_resolver(53)
  103. self.add_async_cleanup(clear_cache)
  104. start = DnsProxy()
  105. server_task = await start()
  106. self.add_async_cleanup(await_cancel, server_task)
  107. with self.assertRaises(DnsResponseCode) as cm:
  108. await resolve('doesnotexist.charemza.name', TYPES.A)
  109. self.assertEqual(cm.exception.args[0], 5)
  110. @async_test
  111. async def test_many_responses_with_small_socket_buffer_no_onward_query(self):
  112. resolve, clear_cache = get_resolver(53)
  113. self.add_async_cleanup(clear_cache)
  114. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
  115. get_resolver=get_fixed_resolver)
  116. server_task = await start()
  117. self.add_async_cleanup(await_cancel, server_task)
  118. tasks = [
  119. asyncio.create_task(resolve('www.google.com', TYPES.A))
  120. for _ in range(0, 100000)
  121. ]
  122. responses = await asyncio.gather(*tasks)
  123. for response in responses:
  124. self.assertEqual(str(response[0]), '1.2.3.4')
  125. bing_responses = await resolve('www.bing.com', TYPES.A)
  126. self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)
  127. @async_test
  128. async def test_many_responses_with_small_socket_buffer_onward_query(self):
  129. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket)
  130. server_task = await start()
  131. self.add_async_cleanup(await_cancel, server_task)
  132. async def resolve(domain):
  133. resolve, clear_cache = get_resolver(53)
  134. result = await resolve(domain, TYPES.A)
  135. await clear_cache()
  136. return result
  137. tasks = [
  138. asyncio.create_task(resolve('www.google.com'))
  139. for _ in range(0, 1000)
  140. ]
  141. responses = await asyncio.gather(*tasks)
  142. for response in responses:
  143. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  144. bing_responses = await resolve('www.bing.com')
  145. self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)
  146. @async_test
  147. async def test_many_responses_with_regular_socket_buffer_onward_query(self):
  148. resolve, clear_cache = get_resolver(53)
  149. self.add_async_cleanup(clear_cache)
  150. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  151. server_task = await start()
  152. self.add_async_cleanup(await_cancel, server_task)
  153. tasks = [
  154. asyncio.create_task(resolve('www.google.com', TYPES.A))
  155. for _ in range(0, 100000)
  156. ]
  157. responses = await asyncio.gather(*tasks)
  158. for response in responses:
  159. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  160. bing_responses = await resolve('www.bing.com', TYPES.A)
  161. self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)
  162. @async_test
  163. async def test_proxy_returns_error_from_upstream(self):
  164. rcode = 4
  165. async def get_response(query_data):
  166. query = parse(query_data)
  167. response = Message(
  168. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
  169. qd=query.qd, an=(), ns=(), ar=(),
  170. )
  171. return pack(response)
  172. stop_nameserver = await start_nameserver(54, get_response)
  173. self.add_async_cleanup(stop_nameserver)
  174. resolve, clear_cache = get_resolver(53)
  175. self.add_async_cleanup(clear_cache)
  176. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  177. server_task = await start()
  178. self.add_async_cleanup(await_cancel, server_task)
  179. with self.assertRaises(DnsResponseCode) as cm:
  180. await resolve('www.google.com', TYPES.A)
  181. self.assertEqual(cm.exception.args[0], 4)
  182. rcode = 5
  183. with self.assertRaises(DnsResponseCode) as cm:
  184. await resolve('www.google.com', TYPES.A)
  185. self.assertEqual(cm.exception.args[0], 5)
  186. @async_test
  187. async def test_sending_bad_messages_not_affect_later_queries_a(self):
  188. resolve, clear_cache = get_resolver(53)
  189. self.add_async_cleanup(clear_cache)
  190. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  191. server_task = await start()
  192. self.add_async_cleanup(await_cancel, server_task)
  193. response = await resolve('www.google.com', TYPES.A)
  194. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  195. for _ in range(0, 100000):
  196. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  197. sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
  198. sock.close()
  199. tasks = [
  200. asyncio.create_task(resolve('www.google.com', TYPES.A))
  201. for _ in range(0, 100000)
  202. ]
  203. responses = await asyncio.gather(*tasks)
  204. for response in responses:
  205. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  206. @async_test
  207. async def test_sending_bad_messages_not_affect_later_queries_b(self):
  208. resolve, clear_cache = get_resolver(53)
  209. self.add_async_cleanup(clear_cache)
  210. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  211. server_task = await start()
  212. self.add_async_cleanup(await_cancel, server_task)
  213. response = await resolve('www.google.com', TYPES.A)
  214. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  215. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  216. for _ in range(0, 100000):
  217. sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
  218. sock.close()
  219. tasks = [
  220. asyncio.create_task(resolve('www.google.com', TYPES.A))
  221. for _ in range(0, 100000)
  222. ]
  223. responses = await asyncio.gather(*tasks)
  224. for response in responses:
  225. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  226. @async_test
  227. async def test_sending_lots_of_good_messages_not_affect_later_queries(self):
  228. resolve, clear_cache = get_resolver(53)
  229. self.add_async_cleanup(clear_cache)
  230. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  231. server_task = await start()
  232. self.add_async_cleanup(await_cancel, server_task)
  233. response = await resolve('www.google.com', TYPES.A)
  234. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  235. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  236. for i in range(0, 100000):
  237. name = b'doesnotexist' + str(i).encode('ascii') + b'.charemza.name'
  238. question_record = QuestionRecord(name, TYPES.A, qclass=1)
  239. question = Message(
  240. qid=i % 65535, qr=QUESTION, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  241. qd=(question_record,), an=(), ns=(), ar=(),
  242. )
  243. sock.sendto(pack(question), ('127.0.0.1', 53))
  244. sock.close()
  245. response = await resolve('www.google.com', TYPES.A)
  246. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  247. @async_test
  248. async def test_sending_pointer_loop_not_affect_later_queries_c(self):
  249. resolve, clear_cache = get_resolver(53)
  250. self.add_async_cleanup(clear_cache)
  251. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  252. server_task = await start()
  253. self.add_async_cleanup(await_cancel, server_task)
  254. response = await resolve('www.google.com', TYPES.A)
  255. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  256. name = b'mydomain.com'
  257. question_record = QuestionRecord(name, TYPES.A, qclass=1)
  258. record_1 = ResourceRecord(
  259. name=name, qtype=TYPES.A, qclass=1, ttl=0,
  260. rdata=ipaddress.IPv4Address('123.100.124.1').packed,
  261. )
  262. response = Message(
  263. qid=1, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  264. qd=(question_record,), an=(record_1,), ns=(), ar=(),
  265. )
  266. data = pack(response)
  267. packed_name = b''.join(
  268. component
  269. for label in name.split(b'.')
  270. for component in (bytes([len(label)]), label)
  271. ) + b'\0'
  272. occurance_1 = data.index(packed_name)
  273. occurance_1_end = occurance_1 + len(packed_name)
  274. occurance_2 = occurance_1_end + data[occurance_1_end:].index(packed_name)
  275. occurance_2_end = occurance_2 + len(packed_name)
  276. data_compressed = \
  277. data[:occurance_2] + \
  278. struct.pack('!H', (192 * 256) + occurance_2 + 4) + \
  279. struct.pack('!H', (192 * 256) + occurance_2) + \
  280. struct.pack('!H', (192 * 256) + occurance_2 + 2) + \
  281. data[occurance_2_end:]
  282. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  283. sock.sendto(data_compressed, ('127.0.0.1', 53))
  284. sock.close()
  285. tasks = [
  286. asyncio.create_task(resolve('www.google.com', TYPES.A))
  287. for _ in range(0, 100000)
  288. ]
  289. responses = await asyncio.gather(*tasks)
  290. for response in responses:
  291. self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
  292. @async_test
  293. async def test_too_large_response_from_upstream_not_affect_later(self):
  294. num_records = 200
  295. async def get_response(query_data):
  296. query = parse(query_data)
  297. response_records = tuple(
  298. ResourceRecord(
  299. name=query.qd[0].name,
  300. qtype=TYPES.A,
  301. qclass=1,
  302. ttl=0,
  303. rdata=ipaddress.IPv4Address('123.100.123.' + str(i)).packed,
  304. ) for i in range(0, num_records)
  305. )
  306. response = Message(
  307. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  308. qd=query.qd, an=response_records, ns=(), ar=(),
  309. )
  310. return pack(response)
  311. stop_nameserver = await start_nameserver(54, get_response)
  312. self.add_async_cleanup(stop_nameserver)
  313. resolve, clear_cache = get_resolver(53)
  314. self.add_async_cleanup(clear_cache)
  315. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  316. server_task = await start()
  317. self.add_async_cleanup(await_cancel, server_task)
  318. tasks = [
  319. asyncio.create_task(resolve('www.google.com', TYPES.A))
  320. for _ in range(0, 100000)
  321. ]
  322. for task in tasks:
  323. with self.assertRaises(DnsTimeout):
  324. await task
  325. num_records = 1
  326. tasks = [
  327. asyncio.create_task(resolve('www.google.com', TYPES.A))
  328. for _ in range(0, 100000)
  329. ]
  330. responses = await asyncio.gather(*tasks)
  331. for response in responses:
  332. self.assertEqual(str(response[0]), '123.100.123.0')
  333. @async_test
  334. async def test_server_response_after_cancel_returned_to_client(self):
  335. received_request = asyncio.Event()
  336. continue_request = asyncio.Event()
  337. async def get_response(query_data):
  338. query = parse(query_data)
  339. response_record = ResourceRecord(
  340. name=query.qd[0].name,
  341. qtype=TYPES.A,
  342. qclass=1,
  343. ttl=0,
  344. rdata=ipaddress.IPv4Address('123.100.123.1').packed,
  345. )
  346. response = Message(
  347. qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
  348. qd=query.qd, an=(response_record,), ns=(), ar=(),
  349. )
  350. received_request.set()
  351. await continue_request.wait()
  352. return pack(response)
  353. stop_nameserver = await start_nameserver(54, get_response)
  354. self.add_async_cleanup(stop_nameserver)
  355. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  356. server_task = await start()
  357. async def resolve(domain):
  358. resolve, clear_cache = get_resolver(53)
  359. result = await resolve(domain, TYPES.A)
  360. await clear_cache()
  361. return result
  362. # Start a set of requests
  363. tasks = [
  364. asyncio.create_task(resolve('www.google.com'))
  365. for _ in range(0, 100)
  366. ]
  367. await received_request.wait()
  368. # Cancel the server...
  369. server_task.cancel()
  370. # ... start a new request
  371. after_cancel_task = asyncio.create_task(resolve('www.bing.com'))
  372. # ... wait to try to ensure the request would have been received
  373. await asyncio.sleep(0.2)
  374. # ... then finally the upstream server continues with the processing
  375. # of the requests received before cancellation
  376. continue_request.set()
  377. for response in await asyncio.gather(*tasks):
  378. self.assertEqual(str(response[0]), '123.100.123.1')
  379. # ... but the request started after cancellation times out
  380. with self.assertRaises(DnsTimeout):
  381. await after_cancel_task
  382. def get_socket(port):
  383. def _get_socket():
  384. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  385. sock.setblocking(False)
  386. sock.bind(('', port))
  387. return sock
  388. return _get_socket
  389. def get_small_socket():
  390. # For linux, the minimum buffer size is 1024
  391. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  392. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
  393. sock.setblocking(False)
  394. sock.bind(('', 53))
  395. return sock
  396. def get_resolver(port, timeout=2.0):
  397. async def get_nameservers(_, __):
  398. for _ in range(0, 5):
  399. yield (timeout, ('127.0.0.1', port))
  400. return Resolver(get_nameservers=get_nameservers)
  401. def get_fixed_resolver():
  402. async def get_host(_, fqdn, qtype):
  403. hosts = {
  404. b'www.google.com': {
  405. TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
  406. },
  407. }
  408. try:
  409. return hosts[fqdn.lower()][qtype]
  410. except KeyError:
  411. return None
  412. return Resolver(get_host=get_host)
  413. async def start_nameserver(port, get_response):
  414. # For some tests we need to control the responses from upstream, especially in the cases
  415. # where it's not behaving
  416. loop = asyncio.get_event_loop()
  417. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  418. sock.setblocking(False)
  419. sock.bind(('', port))
  420. async def server():
  421. client_tasks = []
  422. try:
  423. while True:
  424. data, addr = await recvfrom(loop, [sock], 512)
  425. client_tasks.append(asyncio.ensure_future(client_task(data, addr)))
  426. finally:
  427. for task in client_tasks:
  428. task.cancel()
  429. async def client_task(data, addr):
  430. response = await get_response(data)
  431. sock.sendto(response, addr)
  432. server_task = asyncio.ensure_future(server())
  433. async def stop():
  434. server_task.cancel()
  435. await asyncio.sleep(0)
  436. sock.close()
  437. return stop
  438. async def await_cancel(task):
  439. task.cancel()
  440. try:
  441. await task
  442. except asyncio.CancelledError:
  443. pass