You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
2.6 KiB

import asyncio
import functools
import socket
import unittest

from aiodnsresolver import (
    TYPES,
    Resolver,
    IPv4AddressExpiresAt,
    DnsResponseCode,
    DnsRecordDoesNotExist,
)
from dnsrewriteproxy import (
    DnsProxy,
)
  14. def async_test(func):
  15. def wrapper(*args, **kwargs):
  16. future = func(*args, **kwargs)
  17. loop = asyncio.get_event_loop()
  18. loop.run_until_complete(future)
  19. return wrapper
  20. class TestProxy(unittest.TestCase):
  21. def add_async_cleanup(self, coroutine):
  22. self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
  23. @async_test
  24. async def test_e2e_no_match_rule(self):
  25. resolve, clear_cache = get_resolver(3535)
  26. self.add_async_cleanup(clear_cache)
  27. start = DnsProxy(get_socket=get_socket(3535))
  28. stop = await start()
  29. self.add_async_cleanup(stop)
  30. with self.assertRaises(DnsResponseCode):
  31. await resolve('www.google.com', TYPES.A)
  32. @async_test
  33. async def test_e2e_match_all(self):
  34. resolve, clear_cache = get_resolver(3535)
  35. self.add_async_cleanup(clear_cache)
  36. start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
  37. stop = await start()
  38. self.add_async_cleanup(stop)
  39. response = await resolve('www.google.com', TYPES.A)
  40. self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
  41. @async_test
  42. async def test_e2e_default_port_match_all(self):
  43. resolve, clear_cache = get_resolver(53)
  44. self.add_async_cleanup(clear_cache)
  45. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  46. stop = await start()
  47. self.add_async_cleanup(stop)
  48. response = await resolve('www.google.com', TYPES.A)
  49. self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
  50. @async_test
  51. async def test_e2e_default_resolver_match_all_non_existing_domain(self):
  52. resolve, clear_cache = get_resolver(53)
  53. self.add_async_cleanup(clear_cache)
  54. start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
  55. stop = await start()
  56. self.add_async_cleanup(stop)
  57. with self.assertRaises(DnsRecordDoesNotExist):
  58. await resolve('doesnotexist.charemza.name', TYPES.A)
  59. def get_socket(port):
  60. def _get_socket():
  61. sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  62. sock.setblocking(False)
  63. sock.bind(('', port))
  64. return sock
  65. return _get_socket
  66. def get_resolver(port):
  67. async def get_nameservers(_, __):
  68. for _ in range(0, 5):
  69. yield (0.5, ('127.0.0.1', port))
  70. return Resolver(get_nameservers=get_nameservers)