"""End-to-end tests for ``dnsrewriteproxy.DnsProxy``.

Each test starts a proxy on a local UDP port, queries it through an
``aiodnsresolver`` ``Resolver`` pointed at 127.0.0.1, and checks the answer or
the error surfaced to the client. Several tests resolve real public names
(for example ``www.google.com``) and bind to port 53.
"""
import asyncio
import functools
import ipaddress
import socket
import struct
import unittest

from aiodnsresolver import (
    RESPONSE,
    QUESTION,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    DnsTimeout,
    IPv4AddressExpiresAt,
    Message,
    ResourceRecord,
    QuestionRecord,
    Resolver,
    pack,
    parse,
    recvfrom,
)
from dnsrewriteproxy import (
    DnsProxy,
)


def _get_event_loop():
    """Return the current event loop, creating and installing one if none is set.

    The same loop has to survive between the test body and its cleanups, which
    call ``run_until_complete`` on it again, so the loop is never closed here.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        # Newer Python versions no longer create a loop implicitly
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


def async_test(func):
    """Decorate a coroutine test method so unittest can call it synchronously."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        future = func(*args, **kwargs)
        loop = _get_event_loop()
        loop.run_until_complete(future)
    return wrapper


class TestProxy(unittest.TestCase):
    """End-to-end behaviour of the proxy as seen by a DNS client."""

    def add_async_cleanup(self, coroutine, *args):
        """Register ``coroutine(*args)`` to be run to completion at test cleanup.

        Must be called from inside a running test coroutine: the running loop
        is captured now and reused once the test body has finished.
        """
        self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine(*args))

    @async_test
    async def test_e2e_no_match_rule(self):
        """With no rewrite rules, a query is answered with response code 5."""
        resolve, clear_cache = get_resolver(3535)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket(3535))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)

        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_e2e_match_all(self):
        """A match-everything rule on a custom port resolves an existing name."""
        resolve, clear_cache = get_resolver(3535)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)

        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_port_match_all(self):
        """A match-everything rule works with the proxy's default socket."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)

        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_resolver_match_all_non_existing_domain(self):
        """A matched but non-existent name reaches the client as does-not-exist."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        with self.assertRaises(DnsRecordDoesNotExist):
            await resolve('doesnotexist.charemza.name', TYPES.A)

    @async_test
    async def test_e2e_default_resolver_rewrite_non_existing_to_existing(self):
        """A rule rewriting a non-existent name to an existing one resolves."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'^doesnotexist\.charemza\.name$', r'www.google.com'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('doesnotexist.charemza.name', TYPES.A)

        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_e2e_default_resolver_match_all_bad_upstream(self):
        """With nothing listening upstream (port 54), the client gets code 2."""
        # Long client timeout so the proxy's error response is what is observed
        resolve, clear_cache = get_resolver(53, timeout=100)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)

        self.assertEqual(cm.exception.args[0], 2)

    @async_test
    async def test_e2e_default_resolver_match_none_non_existing_domain(self):
        """With no rules on the default port, a query is answered with code 5."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy()
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('doesnotexist.charemza.name', TYPES.A)

        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_many_responses_with_small_socket_buffer_no_onward_query(self):
        """Many concurrent queries answered from fixed data, small send buffer."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(
            rules=((r'(^.*$)', r'\1'),),
            get_socket=get_small_socket,
            get_resolver=get_fixed_resolver,
        )
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(str(response[0]), '1.2.3.4')

        bing_responses = await resolve('www.bing.com', TYPES.A)
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_many_responses_with_small_socket_buffer_onward_query(self):
        """Many concurrent uncached client queries with a small send buffer."""
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket)
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        async def resolve(domain):
            # A fresh resolver per query, so no client-side cache is shared
            resolve, clear_cache = get_resolver(53)
            result = await resolve(domain, TYPES.A)
            await clear_cache()
            return result

        tasks = [
            asyncio.create_task(resolve('www.google.com'))
            for _ in range(0, 1000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        bing_responses = await resolve('www.bing.com')
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_many_responses_with_regular_socket_buffer_onward_query(self):
        """Many concurrent queries through the proxy's default socket."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        bing_responses = await resolve('www.bing.com', TYPES.A)
        self.assertEqual(type(bing_responses[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_proxy_returns_error_from_upstream(self):
        """The response code chosen by the upstream is passed on to the client."""
        rcode = 4

        async def get_response(query_data):
            # Reads ``rcode`` from the enclosing test at call time, so the test
            # can change the upstream's behaviour between queries
            query = parse(query_data)
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
                qd=query.qd, an=(), ns=(), ar=(),
            )
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)

        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 4)

        rcode = 5
        with self.assertRaises(DnsResponseCode) as cm:
            await resolve('www.google.com', TYPES.A)
        self.assertEqual(cm.exception.args[0], 5)

    @async_test
    async def test_sending_bad_messages_not_affect_later_queries_a(self):
        """Garbage datagrams, each from a new socket, do not break the proxy."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        for _ in range(0, 100000):
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_bad_messages_not_affect_later_queries_b(self):
        """Garbage datagrams, all from one socket, do not break the proxy."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            for _ in range(0, 100000):
                sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53))

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_lots_of_good_messages_not_affect_later_queries(self):
        """A flood of valid questions with unread answers does not break the proxy."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            for i in range(0, 100000):
                name = b'doesnotexist' + str(i).encode('ascii') + b'.charemza.name'
                question_record = QuestionRecord(name, TYPES.A, qclass=1)
                question = Message(
                    qid=i % 65535, qr=QUESTION, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                    qd=(question_record,), an=(), ns=(), ar=(),
                )
                sock.sendto(pack(question), ('127.0.0.1', 53))

        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_sending_pointer_loop_not_affect_later_queries_c(self):
        """A message whose compressed name pointers form a cycle is survived."""
        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        response = await resolve('www.google.com', TYPES.A)
        self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

        name = b'mydomain.com'
        question_record = QuestionRecord(name, TYPES.A, qclass=1)
        record_1 = ResourceRecord(
            name=name, qtype=TYPES.A, qclass=1, ttl=0,
            rdata=ipaddress.IPv4Address('123.100.124.1').packed,
        )
        response = Message(
            qid=1, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=(question_record,), an=(record_1,), ns=(), ar=(),
        )

        data = pack(response)
        # The name as length-prefixed labels followed by a zero byte
        packed_name = b''.join(
            component
            for label in name.split(b'.')
            for component in (bytes([len(label)]), label)
        ) + b'\0'
        occurance_1 = data.index(packed_name)
        occurance_1_end = occurance_1 + len(packed_name)
        occurance_2 = occurance_1_end + data[occurance_1_end:].index(packed_name)
        occurance_2_end = occurance_2 + len(packed_name)

        # Replace the second occurrence of the name with three 2-byte pointers
        # (top two bits set) that refer only to each other:
        # +0 -> +4, +2 -> +0, +4 -> +2
        data_compressed = (
            data[:occurance_2]
            + struct.pack('!H', (192 * 256) + occurance_2 + 4)
            + struct.pack('!H', (192 * 256) + occurance_2)
            + struct.pack('!H', (192 * 256) + occurance_2 + 2)
            + data[occurance_2_end:]
        )

        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.sendto(data_compressed, ('127.0.0.1', 53))

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

    @async_test
    async def test_too_large_response_from_upstream_not_affect_later(self):
        """A 200-record upstream answer times out; later small answers succeed."""
        num_records = 200

        async def get_response(query_data):
            # Reads ``num_records`` from the enclosing test at call time
            query = parse(query_data)
            response_records = tuple(
                ResourceRecord(
                    name=query.qd[0].name,
                    qtype=TYPES.A,
                    qclass=1,
                    ttl=0,
                    rdata=ipaddress.IPv4Address('123.100.123.' + str(i)).packed,
                ) for i in range(0, num_records)
            )
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                qd=query.qd, an=response_records, ns=(), ar=(),
            )
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)

        resolve, clear_cache = get_resolver(53)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()
        self.add_async_cleanup(await_cancel, server_task)

        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        for task in tasks:
            with self.assertRaises(DnsTimeout):
                await task

        num_records = 1
        tasks = [
            asyncio.create_task(resolve('www.google.com', TYPES.A))
            for _ in range(0, 100000)
        ]
        responses = await asyncio.gather(*tasks)

        for response in responses:
            self.assertEqual(str(response[0]), '123.100.123.0')

    @async_test
    async def test_server_response_after_cancel_returned_to_client(self):
        """Requests in flight at cancellation are answered; later ones time out."""
        received_request = asyncio.Event()
        continue_request = asyncio.Event()

        async def get_response(query_data):
            query = parse(query_data)
            response_record = ResourceRecord(
                name=query.qd[0].name,
                qtype=TYPES.A,
                qclass=1,
                ttl=0,
                rdata=ipaddress.IPv4Address('123.100.123.1').packed,
            )
            response = Message(
                qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
                qd=query.qd, an=(response_record,), ns=(), ar=(),
            )
            # Signal the test that a request arrived, then hold the answer
            # back until the test allows it through
            received_request.set()
            await continue_request.wait()
            return pack(response)

        stop_nameserver = await start_nameserver(54, get_response)
        self.add_async_cleanup(stop_nameserver)

        start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
        server_task = await start()

        async def resolve(domain):
            # A fresh resolver per query, so no client-side cache is shared
            resolve, clear_cache = get_resolver(53)
            result = await resolve(domain, TYPES.A)
            await clear_cache()
            return result

        # Start a set of requests
        tasks = [
            asyncio.create_task(resolve('www.google.com'))
            for _ in range(0, 100)
        ]
        await received_request.wait()

        # Cancel the server...
        server_task.cancel()

        # ... start a new request
        after_cancel_task = asyncio.create_task(resolve('www.bing.com'))

        # ... wait to try to ensure the request would have been received
        await asyncio.sleep(0.2)

        # ... then finally the upstream server continues with the processing
        # of the requests received before cancellation
        continue_request.set()

        for response in await asyncio.gather(*tasks):
            self.assertEqual(str(response[0]), '123.100.123.1')

        # ... but the request started after cancellation times out
        with self.assertRaises(DnsTimeout):
            await after_cancel_task


def _bind_or_close(sock, address):
    """Bind ``sock`` to ``address``, closing it before re-raising if that fails."""
    try:
        sock.bind(address)
    except OSError:
        sock.close()
        raise


def get_socket(port):
    """Return a factory for a non-blocking UDP socket bound to ``port``."""
    def _get_socket():
        sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
        sock.setblocking(False)
        _bind_or_close(sock, ('', port))
        return sock
    return _get_socket


def get_small_socket():
    """Return a non-blocking UDP socket on port 53 with a 1024 byte send buffer."""
    # For linux, the minimum buffer size is 1024
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
    sock.setblocking(False)
    _bind_or_close(sock, ('', 53))
    return sock


def get_resolver(port, timeout=2.0):
    """Return a ``Resolver`` whose only nameserver is 127.0.0.1 on ``port``.

    The nameserver is offered five times, each with ``timeout``. Callers in
    this file unpack the result as ``resolve, clear_cache``.
    """
    async def get_nameservers(_, __):
        for _ in range(0, 5):
            yield (timeout, ('127.0.0.1', port))

    return Resolver(get_nameservers=get_nameservers)


def get_fixed_resolver():
    """Return a ``Resolver`` whose ``get_host`` maps www.google.com A to 1.2.3.4.

    For any other name or record type ``get_host`` returns ``None``.
    """
    async def get_host(_, fqdn, qtype):
        hosts = {
            b'www.google.com': {
                TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
            },
        }
        try:
            return hosts[fqdn.lower()][qtype]
        except KeyError:
            return None

    return Resolver(get_host=get_host)


async def start_nameserver(port, get_response):
    """Start a UDP nameserver on ``port`` and return a coroutine function to stop it.

    ``get_response`` is awaited with the raw bytes of each received datagram
    and must return the raw bytes to send back to the sender.
    """
    # For some tests we need to control the responses from upstream, especially in the cases
    # where it's not behaving

    loop = asyncio.get_running_loop()

    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    _bind_or_close(sock, ('', port))

    async def server():
        client_tasks = []
        try:
            while True:
                data, addr = await recvfrom(loop, [sock], 512)
                client_tasks.append(asyncio.ensure_future(client_task(data, addr)))
        finally:
            # Runs when the server task is cancelled by stop()
            for task in client_tasks:
                task.cancel()

    async def client_task(data, addr):
        response = await get_response(data)
        sock.sendto(response, addr)

    server_task = asyncio.ensure_future(server())

    async def stop():
        server_task.cancel()
        # Yield once so the cancellation is delivered before the socket closes
        await asyncio.sleep(0)
        sock.close()

    return stop


async def await_cancel(task):
    """Cancel ``task`` and wait for it to finish, ignoring the cancellation error."""
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass