@@ -35,16 +35,16 @@ def async_test(func):
class TestProxy(unittest.TestCase):
def add_async_cleanup(self, coroutine):
self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
def add_async_cleanup(self, coroutine, *args ):
self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine(*args ))
@async_test
async def test_e2e_no_match_rule(self):
resolve, clear_cache = get_resolver(3535)
self.add_async_cleanup(clear_cache)
start = DnsProxy(get_socket=get_socket(3535))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -56,8 +56,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(3535)
self.add_async_cleanup(clear_cache)
start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
@@ -68,8 +68,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
@@ -80,8 +80,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
with self.assertRaises(DnsRecordDoesNotExist):
await resolve('doesnotexist.charemza.name', TYPES.A)
@@ -91,8 +91,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'^doesnotexist\.charemza\.name$', r'www.google.com'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('doesnotexist.charemza.name', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -103,8 +103,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -116,8 +116,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy()
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
with self.assertRaises(DnsResponseCode) as cm:
await resolve('doesnotexist.charemza.name', TYPES.A)
@@ -131,8 +131,8 @@ class TestProxy(unittest.TestCase):
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
get_resolver=get_fixed_resolver)
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -150,8 +150,8 @@ class TestProxy(unittest.TestCase):
@async_test
async def test_many_responses_with_small_socket_buffer_onward_query(self):
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket)
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
async def resolve(domain):
resolve, clear_cache = get_resolver(53)
@@ -178,8 +178,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -213,8 +213,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -233,8 +233,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -258,8 +258,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -283,8 +283,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -310,8 +310,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -387,8 +387,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop )
server_task = await start()
self.add_async_cleanup(await_cancel, server_task )
tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -408,6 +408,67 @@ class TestProxy(unittest.TestCase):
for response in responses:
self.assertEqual(str(response[0]), '123.100.123.0')
@async_test
async def test_server_response_after_cancel_returned_to_client(self):
received_request = asyncio.Event()
continue_request = asyncio.Event()
async def get_response(query_data):
query = parse(query_data)
response_record = ResourceRecord(
name=query.qd[0].name,
qtype=TYPES.A,
qclass=1,
ttl=0,
rdata=ipaddress.IPv4Address('123.100.123.1').packed,
)
response = Message(
qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
qd=query.qd, an=(response_record,), ns=(), ar=(),
)
received_request.set()
await continue_request.wait()
return pack(response)
stop_nameserver = await start_nameserver(54, get_response)
self.add_async_cleanup(stop_nameserver)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
server_task = await start()
async def resolve(domain):
resolve, clear_cache = get_resolver(53)
result = await resolve(domain, TYPES.A)
await clear_cache()
return result
# Start a set of requests
tasks = [
asyncio.create_task(resolve('www.google.com'))
for _ in range(0, 1000)
]
await received_request.wait()
# Cancel the server...
server_task.cancel()
# ... start a new request
after_cancel_task = asyncio.create_task(resolve('www.bing.com'))
# ... wait 5 seconds
await asyncio.sleep(0.5)
# ... then finally the upstream server continues with the processing
# of the requests received before cancellation
continue_request.set()
for response in await asyncio.gather(*tasks):
self.assertEqual(str(response[0]), '123.100.123.1')
# ... but the request started after cancellation times out
with self.assertRaises(DnsTimeout):
await after_cancel_task
def get_socket(port):
def _get_socket():
@@ -481,3 +542,11 @@ async def start_nameserver(port, get_response):
sock.close()
return stop
async def await_cancel(task):
task.cancel()
try:
await task
except asyncio.CancelledError:
pass