diff --git a/dnsrewriteproxy.py b/dnsrewriteproxy.py index e10506e..ade4e69 100644 --- a/dnsrewriteproxy.py +++ b/dnsrewriteproxy.py @@ -182,9 +182,12 @@ def DnsProxy( except DnsResponseCode as dns_response_code_error: return error(query, dns_response_code_error.args[0]) + def ttl(ip_address): + return int(max(0.0, ip_address.expires_at - loop.time())) + reponse_records = tuple( ResourceRecord(name=name_bytes, qtype=TYPES.A, - qclass=1, ttl=5, rdata=ip_address.packed) + qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed) for ip_address in ip_addresses ) return Message( diff --git a/test.py b/test.py index a4e8531..76fd237 100644 --- a/test.py +++ b/test.py @@ -28,9 +28,9 @@ class TestProxy(unittest.TestCase): @async_test async def test_e2e_no_match_rule(self): - resolve, clear_cache = get_resolver() + resolve, clear_cache = get_resolver(3535) self.add_async_cleanup(clear_cache) - start = DnsProxy(get_socket=get_socket) + start = DnsProxy(get_socket=get_socket(3535)) stop = await start() self.add_async_cleanup(stop) @@ -39,9 +39,33 @@ class TestProxy(unittest.TestCase): @async_test async def test_e2e_match_all(self): - resolve, clear_cache = get_resolver() + resolve, clear_cache = get_resolver(3535) self.add_async_cleanup(clear_cache) - start = DnsProxy(get_socket=get_socket, rules=((r'(^.*$)', r'\1'),)) + start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),)) + stop = await start() + self.add_async_cleanup(stop) + + response = await resolve('www.google.com', TYPES.A) + + self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt)) + + @async_test + async def test_e2e_default_port_match_all(self): + resolve, clear_cache = get_resolver(53) + self.add_async_cleanup(clear_cache) + start = DnsProxy(rules=((r'(^.*$)', r'\1'),)) + stop = await start() + self.add_async_cleanup(stop) + + response = await resolve('www.google.com', TYPES.A) + + self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt)) + + @async_test + async def test_e2e_default_resolver_match_all(self): + resolve, clear_cache = Resolver() + self.add_async_cleanup(clear_cache) + start = DnsProxy(rules=((r'(^.*$)', r'\1'),)) stop = await start() self.add_async_cleanup(stop) @@ -50,16 +74,18 @@ class TestProxy(unittest.TestCase): self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt)) -def get_socket(): - sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) - sock.setblocking(False) - sock.bind(('', 3535)) - return sock +def get_socket(port): + def _get_socket(): + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind(('', port)) + return sock + return _get_socket -def get_resolver(): +def get_resolver(port): async def get_nameservers(_, __): for _ in range(0, 5): - yield (0.5, ('127.0.0.1', 3535)) + yield (0.5, ('127.0.0.1', port)) return Resolver(get_nameservers=get_nameservers)