You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

553 lines
18 KiB

  1. import asyncio
  2. import ipaddress
  3. import socket
  4. import struct
  5. import unittest
  6. from aiodnsresolver import (
  7. RESPONSE,
  8. QUESTION,
  9. TYPES,
  10. DnsRecordDoesNotExist,
  11. DnsResponseCode,
  12. DnsTimeout,
  13. IPv4AddressExpiresAt,
  14. Message,
  15. ResourceRecord,
  16. QuestionRecord,
  17. Resolver,
  18. pack,
  19. parse,
  20. recvfrom,
  21. )
  22. from dnsrewriteproxy import (
  23. DnsProxy,
  24. )
  25. def async_test(func):
  26. def wrapper(*args, **kwargs):
  27. future = func(*args, **kwargs)
  28. loop = asyncio.get_event_loop()
  29. loop.run_until_complete(future)
  30. return wrapper
class TestProxy(unittest.TestCase):
    """End-to-end tests for DnsProxy.

    Each test starts a proxy via ``await DnsProxy(...)()`` and queries it with
    an aiodnsresolver Resolver aimed at 127.0.0.1 (see ``get_resolver``). When
    no ``get_socket`` is passed, the tests query port 53, so the proxy
    presumably listens there by default. NOTE(review): binding port 53 usually
    needs elevated privileges, and several tests resolve real names
    (www.google.com, www.bing.com), so network access appears to be required.
    Rules look like ``(regex, replacement)`` pairs; ``(r'(^.*$)', r'\\1')``
    rewrites every name to itself.
    """

    def add_async_cleanup(self, coroutine, *args):
        """Register ``coroutine(*args)`` to run on the current loop at cleanup.

        The coroutine object is created immediately. The loop is captured now,
        so this must be called from inside a running coroutine. It is driven
        with ``run_until_complete`` when unittest runs cleanups; that happens
        after ``async_test`` has finished running the test, so the loop is idle
        by then. unittest runs cleanups in LIFO order.
        """
        self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine(*args))

    @async_test
    async def test_e2e_no_match_rule(self):
        # With no rewrite rules (proxy on port 3535 here), the query is expected
        # to be answered with rcode 5 (REFUSED in the RFC 1035 RCODE table).
        resolve, clear_cache = get_resolver(3535)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket(3535))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_e2e_match_all(self):
        # A catch-all identity rule lets the query through; a real A record
        # for www.google.com is expected back.
        resolve, clear_cache = get_resolver(3535)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_port_match_all(self):
        # Same as test_e2e_match_all but with DnsProxy's default socket,
        # queried on port 53.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_resolver_match_all_non_existing_domain(self):
        # A name allowed through by the rules, but which does not exist
        # upstream, surfaces to the client as DnsRecordDoesNotExist.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        with self.assertRaises(DnsRecordDoesNotExist):
            await resolve('doesnotexist.charemza.name', TYPES.A)

    @async_test
    async def test_e2e_default_resolver_rewrite_non_existing_to_existing(self):
        # The rule rewrites a non-existent name to www.google.com, so the
        # client gets an A record for a name that does not exist itself.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'^doesnotexist\.charemza\.name$', r'www.google.com'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('doesnotexist.charemza.name', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_resolver_match_all_bad_upstream(self):
        # The proxy's upstream is 127.0.0.1:54, where this test starts no
        # server. The client gets a long timeout (100) so that, presumably, the
        # proxy gives up on upstream first and answers rcode 2 (SERVFAIL in
        # the RFC 1035 RCODE table).
        resolve, clear_cache = get_resolver(53, timeout=100)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 2)

    @async_test
    async def test_e2e_default_resolver_match_none_non_existing_domain(self):
        # No rules at all: even a non-existent name is answered with rcode 5
        # rather than a does-not-exist result.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy()
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('doesnotexist.charemza.name', TYPES.A)
        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_many_responses_with_small_socket_buffer_no_onward_query(self):
        # 100000 concurrent queries against a proxy whose socket has a tiny
        # send buffer. Its upstream resolver answers www.google.com locally
        # (1.2.3.4) via get_fixed_resolver. A later query for another name
        # must still succeed.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
                         get_resolver=get_fixed_resolver)
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(str(response[0]), '1.2.3.4')
        bing_responses = await resolve('www.bing.com', TYPES.A)
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_many_responses_with_small_socket_buffer_onward_query(self):
        # Small send buffer with real onward (upstream) queries. Each client
        # query below uses a fresh Resolver and clears its cache, presumably so
        # every one of them actually reaches the proxy.
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket)
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        async def resolve(domain):
            resolve, clear_cache = get_resolver(53)
            result = await resolve(domain, TYPES.A)
            await clear_cache()
            return result

        tasks = [
            asyncio.create_task(resolve('www.google.com'))
            for _ in range(0, 1000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        bing_responses = await resolve('www.bing.com')
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_many_responses_with_regular_socket_buffer_onward_query(self):
        # 100000 concurrent client queries through the default socket, then a
        # query for a different name.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        bing_responses = await resolve('www.bing.com', TYPES.A)
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_proxy_returns_error_from_upstream(self):
        # A fake upstream on port 54 answers with whatever ``rcode`` currently
        # holds. get_response reads it through the closure at call time, so
        # reassigning it changes the next upstream answer. The proxy must
        # relay each rcode unchanged.
        rcode = 4

        async def get_response(query_data):
            query = parse(query_data)
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
                qd=query.qd, an=(), ns=(), ar=(),
            )
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 4)
        rcode = 5
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_sending_bad_messages_not_affect_later_queries_a(self):
        # 100000 unparseable datagrams, each from a new socket (so a new
        # source port), must not break queries sent afterwards.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        for _ in range(0, 100000):
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
            sock.close()
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_bad_messages_not_affect_later_queries_b(self):
        # Variant of _a: all 100000 unparseable datagrams come from one socket.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        for _ in range(0, 100000):
            sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))
        sock.close()
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_lots_of_good_messages_not_affect_later_queries(self):
        # 100000 well-formed queries for distinct non-existent names are fired
        # from one socket, which is closed without reading any replies. A
        # normal query afterwards must still succeed. qid wraps at 65535.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        for i in range(0, 100000):
            name = b'doesnotexist' + str(i).encode('ascii') + b'.charemza.name'
            question_record = QuestionRecord(name, TYPES.A, qclass=1)
            question = Message(
                qid=i % 65535, qr=QUESTION, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                qd=(question_record,), an=(), ns=(), ar=(),
            )
            sock.sendto(pack(question), ('127.0.0.1', 53))
        sock.close()
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_pointer_loop_not_affect_later_queries_c(self):
        # Sends the proxy a message whose second copy of the name is replaced
        # by a cycle of DNS compression pointers. Later queries must be
        # unaffected.
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
        name = b'mydomain.com'
        question_record = QuestionRecord(name, TYPES.A, qclass=1)
        record_1 = ResourceRecord(
            name=name, qtype=TYPES.A, qclass=1, ttl=0,
            rdata=ipaddress.IPv4Address('123.100.124.1').packed,
        )
        response = Message(
            qid=1, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=(question_record,), an=(record_1,), ns=(), ar=(),
        )
        data = pack(response)
        # Wire form of ``name``: length-prefixed labels plus the terminating
        # zero byte.
        packed_name = b''.join(
            component
            for label in name.split(b'.')
            for component in (bytes([len(label)]), label)
        ) + b'\0'
        # Find the name in the question section, then its next occurrence
        # (the answer record's owner name).
        occurance_1 = data.index(packed_name)
        occurance_1_end = occurance_1 + len(packed_name)
        occurance_2 = occurance_1_end + data[occurance_1_end:].index(packed_name)
        occurance_2_end = occurance_2 + len(packed_name)
        # Replace the second occurrence with three 2-byte pointers
        # (0xC000 | offset) forming a cycle:
        # occ2 -> occ2+4 -> occ2+2 -> occ2 -> ...
        data_compressed = \
            data[:occurance_2] + \
            struct.pack('!H', (192 * 256) + occurance_2 + 4) + \
            struct.pack('!H', (192 * 256) + occurance_2) + \
            struct.pack('!H', (192 * 256) + occurance_2 + 2) + \
            data[occurance_2_end:]
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.sendto(data_compressed, ('127.0.0.1', 53))
        sock.close()
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_too_large_response_from_upstream_not_affect_later(self):
        # The fake upstream first answers with ``num_records`` (200) A records,
        # far more than fits in 512 bytes, and every client query is expected
        # to time out. After ``num_records`` is set to 1 (read through the
        # closure on each call), normal answers must flow again.
        num_records = 200

        async def get_response(query_data):
            query = parse(query_data)
            response_records = tuple(
                ResourceRecord(
                    name=query.qd[0].name,
                    qtype=TYPES.A,
                    qclass=1,
                    ttl=0,
                    rdata=ipaddress.IPv4Address('123.100.123.' + str(i)).packed,
                ) for i in range(0, num_records)
            )
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                qd=query.qd, an=response_records, ns=(), ar=(),
            )
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        for task in tasks:
            with self.assertRaises(DnsTimeout):
                await task
        num_records = 1
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)
        for response in responses:
            self.assertEqual(str(response[0]), '123.100.123.0')

    @async_test
    async def test_server_response_after_cancel_returned_to_client(self):
        # Requests already in flight when the proxy task is cancelled must
        # still get their answers. A request sent after cancellation must time
        # out. server_task is cancelled explicitly below, so it is not
        # registered with add_async_cleanup.
        received_request = asyncio.Event()
        continue_request = asyncio.Event()

        async def get_response(query_data):
            query = parse(query_data)
            response_record = ResourceRecord(
                name=query.qd[0].name,
                qtype=TYPES.A,
                qclass=1,
                ttl=0,
                rdata=ipaddress.IPv4Address('123.100.123.1').packed,
            )
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                qd=query.qd, an=(response_record,), ns=(), ar=(),
            )
            # Signal that upstream has a request, then hold the answer until
            # the test says to continue.
            received_request.set()
            await continue_request.wait()
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()

        async def resolve(domain):
            resolve, clear_cache = get_resolver(53)
            result = await resolve(domain, TYPES.A)
            await clear_cache()
            return result

        # Start a set of requests
        tasks = [
            asyncio.create_task(resolve('www.google.com'))
            for _ in range(0, 100)
        ]
        await received_request.wait()
        # Cancel the server...
        server_task.cancel()
        # ... start a new request
        after_cancel_task = asyncio.create_task(resolve('www.bing.com'))
        # ... wait to try to ensure the request would have been received
        await asyncio.sleep(0.2)
        # ... then finally the upstream server continues with the processing
        # of the requests received before cancellation
        continue_request.set()
        for response in await asyncio.gather(*tasks):
            self.assertEqual(str(response[0]), '123.100.123.1')
        # ... but the request started after cancellation times out
        with self.assertRaises(DnsTimeout):
            await after_cancel_task
  372. def get_socket(port):
  373. def _get_socket():
  374. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  375. sock.setblocking(False)
  376. sock.bind(('', port))
  377. return sock
  378. return _get_socket
  379. def get_small_socket():
  380. # For linux, the minimum buffer size is 1024
  381. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  382. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
  383. sock.setblocking(False)
  384. sock.bind(('', 53))
  385. return sock
  386. def get_resolver(port, timeout=2.0):
  387. async def get_nameservers(_, __):
  388. for _ in range(0, 5):
  389. yield (timeout, ('127.0.0.1', port))
  390. return Resolver(get_nameservers=get_nameservers)
  391. def get_fixed_resolver():
  392. async def get_host(_, fqdn, qtype):
  393. hosts = {
  394. b'www.google.com': {
  395. TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
  396. },
  397. }
  398. try:
  399. return hosts[fqdn.lower()][qtype]
  400. except KeyError:
  401. return None
  402. return Resolver(get_host=get_host)
  403. async def start_nameserver(port, get_response):
  404. # For some tests we need to control the responses from upstream, especially in the cases
  405. # where it's not behaving
  406. loop = asyncio.get_event_loop()
  407. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  408. sock.setblocking(False)
  409. sock.bind(('', port))
  410. async def server():
  411. client_tasks = []
  412. try:
  413. while True:
  414. data, addr = await recvfrom(loop, [sock], 512)
  415. client_tasks.append(asyncio.ensure_future(client_task(data, addr)))
  416. finally:
  417. for task in client_tasks:
  418. task.cancel()
  419. async def client_task(data, addr):
  420. response = await get_response(data)
  421. sock.sendto(response, addr)
  422. server_task = asyncio.ensure_future(server())
  423. async def stop():
  424. server_task.cancel()
  425. await asyncio.sleep(0)
  426. sock.close()
  427. return stop
  428. async def await_cancel(task):
  429. task.cancel()
  430. try:
  431. await task
  432. except asyncio.CancelledError:
  433. pass