diff --git a/test.py b/test.py index ad710b5..01c6524 100644 --- a/test.py +++ b/test.py @@ -4,11 +4,16 @@ import unittest from aiodnsresolver import ( + RESPONSE, TYPES, - Resolver, - IPv4AddressExpiresAt, - DnsResponseCode, DnsRecordDoesNotExist, + DnsResponseCode, + IPv4AddressExpiresAt, + Message, + Resolver, + pack, + parse, + recvfrom, ) from dnsrewriteproxy import ( DnsProxy, @@ -125,6 +130,39 @@ class TestProxy(unittest.TestCase): bing_responses = await resolve('www.bing.com', TYPES.A) self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt) + @async_test + async def test_proxy_returns_error_from_upstream(self): + rcode = 4 + + async def get_response(query_data): + query = parse(query_data) + response = Message( + qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode, + qd=query.qd, an=(), ns=(), ar=(), + ) + return pack(response) + + stop_nameserver = await start_nameserver(54, get_response) + self.add_async_cleanup(stop_nameserver) + + resolve, clear_cache = get_resolver(53) + self.add_async_cleanup(clear_cache) + + start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54)) + stop = await start() + self.add_async_cleanup(stop) + + with self.assertRaises(DnsResponseCode) as cm: + await resolve('www.google.com', TYPES.A) + + self.assertEqual(cm.exception.args[0], 4) + + rcode = 5 + with self.assertRaises(DnsResponseCode) as cm: + await resolve('www.google.com', TYPES.A) + + self.assertEqual(cm.exception.args[0], 5) + def get_socket(port): def _get_socket(): @@ -165,3 +203,36 @@ def get_fixed_resolver(): return None return Resolver(get_host=get_host) + + +async def start_nameserver(port, get_response): + # For some tests we need to control the responses from upstream, especially in the cases + # where it's not behaving + loop = asyncio.get_event_loop() + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind(('', port)) + + async def server(): + client_tasks = [] + try: + while True: + data, addr = await recvfrom(loop, [sock], 512) + client_tasks.append(asyncio.ensure_future(client_task(data, addr))) + finally: + for task in client_tasks: + task.cancel() + + async def client_task(data, addr): + response = await get_response(data) + sock.sendto(response, addr) + + server_task = asyncio.ensure_future(server()) + + async def stop(): + server_task.cancel() + await asyncio.sleep(0) + sock.close() + + return stop