You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

168 lines
5.0 KiB

  1. import asyncio
  2. import socket
  3. import unittest
  4. from aiodnsresolver import (
  5. TYPES,
  6. Resolver,
  7. IPv4AddressExpiresAt,
  8. DnsResponseCode,
  9. DnsRecordDoesNotExist,
  10. )
  11. from dnsrewriteproxy import (
  12. DnsProxy,
  13. )
  14. def async_test(func):
  15. def wrapper(*args, **kwargs):
  16. future = func(*args, **kwargs)
  17. loop = asyncio.get_event_loop()
  18. loop.run_until_complete(future)
  19. return wrapper
  20. class TestProxy(unittest.TestCase):
  21. def add_async_cleanup(self, coroutine):
  22. self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
  23. @async_test
  24. async def test_e2e_no_match_rule(self):
  25. resolve, clear_cache = get_resolver(3535)
  26. self.add_async_cleanup(clear_cache)
  27. start = DnsProxy(get_socket=get_socket(3535))
  28. stop = await start()
  29. self.add_async_cleanup(stop)
  30. with self.assertRaises(DnsResponseCode) as cm:
  31. await resolve('www.google.com', TYPES.A)
  32. self.assertEqual(cm.exception.args[0], 5)
  33. @async_test
  34. async def test_e2e_match_all(self):
  35. resolve, clear_cache = get_resolver(3535)
  36. self.add_async_cleanup(clear_cache)
  37. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  38. stop = await start()
  39. self.add_async_cleanup(stop)
  40. response = await resolve('www.google.com', TYPES.A)
  41. self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
  42. @async_test
  43. async def test_e2e_default_port_match_all(self):
  44. resolve, clear_cache = get_resolver(53)
  45. self.add_async_cleanup(clear_cache)
  46. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  47. stop = await start()
  48. self.add_async_cleanup(stop)
  49. response = await resolve('www.google.com', TYPES.A)
  50. self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
  51. @async_test
  52. async def test_e2e_default_resolver_match_all_non_existing_domain(self):
  53. resolve, clear_cache = get_resolver(53)
  54. self.add_async_cleanup(clear_cache)
  55. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  56. stop = await start()
  57. self.add_async_cleanup(stop)
  58. with self.assertRaises(DnsRecordDoesNotExist):
  59. await resolve('doesnotexist.charemza.name', TYPES.A)
  60. @async_test
  61. async def test_e2e_default_resolver_match_all_bad_upstream(self):
  62. resolve, clear_cache = get_resolver(53, timeout=100)
  63. self.add_async_cleanup(clear_cache)
  64. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
  65. stop = await start()
  66. self.add_async_cleanup(stop)
  67. with self.assertRaises(DnsResponseCode) as cm:
  68. await resolve('www.google.com', TYPES.A)
  69. self.assertEqual(cm.exception.args[0], 2)
  70. @async_test
  71. async def test_e2e_default_resolver_match_none_non_existing_domain(self):
  72. resolve, clear_cache = get_resolver(53)
  73. self.add_async_cleanup(clear_cache)
  74. start = DnsProxy()
  75. stop = await start()
  76. self.add_async_cleanup(stop)
  77. with self.assertRaises(DnsResponseCode) as cm:
  78. await resolve('doesnotexist.charemza.name', TYPES.A)
  79. self.assertEqual(cm.exception.args[0], 5)
  80. @async_test
  81. async def test_many_responses_with_small_socket_buffer(self):
  82. resolve, clear_cache = get_resolver(53)
  83. self.add_async_cleanup(clear_cache)
  84. start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
  85. get_resolver=get_fixed_resolver)
  86. stop = await start()
  87. self.add_async_cleanup(stop)
  88. tasks = [
  89. asyncio.create_task(resolve('www.google.com', TYPES.A))
  90. for _ in range(0, 100000)
  91. ]
  92. responses = await asyncio.gather(*tasks)
  93. for response in responses:
  94. self.assertEqual(str(response[0]), '1.2.3.4')
  95. bing_responses = await resolve('www.bing.com', TYPES.A)
  96. self.assertTrue(isinstance(bing_responses[0], IPv4AddressExpiresAt))
  97. def get_socket(port):
  98. def _get_socket():
  99. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  100. sock.setblocking(False)
  101. sock.bind(('', port))
  102. return sock
  103. return _get_socket
  104. def get_small_socket():
  105. # For linux, the minimum buffer size is 1024
  106. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  107. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
  108. sock.setblocking(False)
  109. sock.bind(('', 53))
  110. return sock
  111. def get_resolver(port, timeout=0.5):
  112. async def get_nameservers(_, __):
  113. for _ in range(0, 5):
  114. yield (timeout, ('127.0.0.1', port))
  115. return Resolver(get_nameservers=get_nameservers)
  116. def get_fixed_resolver():
  117. async def get_host(_, fqdn, qtype):
  118. hosts = {
  119. b'www.google.com': {
  120. TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
  121. },
  122. }
  123. try:
  124. return hosts[fqdn.lower()][qtype]
  125. except KeyError:
  126. return None
  127. return Resolver(get_host=get_host)