| @@ -4,11 +4,16 @@ import unittest | |||
| from aiodnsresolver import ( | |||
| RESPONSE, | |||
| TYPES, | |||
| Resolver, | |||
| IPv4AddressExpiresAt, | |||
| DnsResponseCode, | |||
| DnsRecordDoesNotExist, | |||
| DnsResponseCode, | |||
| IPv4AddressExpiresAt, | |||
| Message, | |||
| Resolver, | |||
| pack, | |||
| parse, | |||
| recvfrom, | |||
| ) | |||
| from dnsrewriteproxy import ( | |||
| DnsProxy, | |||
| @@ -125,6 +130,39 @@ class TestProxy(unittest.TestCase): | |||
| bing_responses = await resolve('www.bing.com', TYPES.A) | |||
| self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt) | |||
| @async_test | |||
| async def test_proxy_returns_error_from_upstream(self): | |||
| rcode = 4 | |||
| async def get_response(query_data): | |||
| query = parse(query_data) | |||
| response = Message( | |||
| qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode, | |||
| qd=query.qd, an=(), ns=(), ar=(), | |||
| ) | |||
| return pack(response) | |||
| stop_nameserver = await start_nameserver(54, get_response) | |||
| self.add_async_cleanup(stop_nameserver) | |||
| resolve, clear_cache = get_resolver(53) | |||
| self.add_async_cleanup(clear_cache) | |||
| start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54)) | |||
| stop = await start() | |||
| self.add_async_cleanup(stop) | |||
| with self.assertRaises(DnsResponseCode) as cm: | |||
| await resolve('www.google.com', TYPES.A) | |||
| self.assertEqual(cm.exception.args[0], 4) | |||
| rcode = 5 | |||
| with self.assertRaises(DnsResponseCode) as cm: | |||
| await resolve('www.google.com', TYPES.A) | |||
| self.assertEqual(cm.exception.args[0], 5) | |||
| def get_socket(port): | |||
| def _get_socket(): | |||
| @@ -165,3 +203,36 @@ def get_fixed_resolver(): | |||
| return None | |||
| return Resolver(get_host=get_host) | |||
| async def start_nameserver(port, get_response): | |||
| # For some tests we need to control the responses from upstream, especially in the cases | |||
| # where it's not behaving | |||
| loop = asyncio.get_event_loop() | |||
| sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) | |||
| sock.setblocking(False) | |||
| sock.bind(('', port)) | |||
| async def server(): | |||
| client_tasks = [] | |||
| try: | |||
| while True: | |||
| data, addr = await recvfrom(loop, [sock], 512) | |||
| client_tasks.append(asyncio.ensure_future(client_task(data, addr))) | |||
| finally: | |||
| for task in client_tasks: | |||
| task.cancel() | |||
| async def client_task(data, addr): | |||
| response = await get_response(data) | |||
| sock.sendto(response, addr) | |||
| server_task = asyncio.ensure_future(server()) | |||
| async def stop(): | |||
| server_task.cancel() | |||
| await asyncio.sleep(0) | |||
| sock.close() | |||
| return stop | |||