|
|
|
@@ -10,6 +10,7 @@ from aiodnsresolver import ( |
|
|
|
TYPES, |
|
|
|
DnsRecordDoesNotExist, |
|
|
|
DnsResponseCode, |
|
|
|
DnsTimeout, |
|
|
|
IPv4AddressExpiresAt, |
|
|
|
Message, |
|
|
|
ResourceRecord, |
|
|
|
@@ -271,6 +272,56 @@ class TestProxy(unittest.TestCase): |
|
|
|
for response in responses: |
|
|
|
self.assertEqual(type(response[0]), IPv4AddressExpiresAt) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_too_large_response_from_upstream_not_affect_later(self): |
|
|
|
num_records = 200 |
|
|
|
|
|
|
|
async def get_response(query_data): |
|
|
|
query = parse(query_data) |
|
|
|
response_records = tuple( |
|
|
|
ResourceRecord( |
|
|
|
name=query.qd[0].name, |
|
|
|
qtype=TYPES.A, |
|
|
|
qclass=1, |
|
|
|
ttl=0, |
|
|
|
rdata=ipaddress.IPv4Address('123.100.123.' + str(i)).packed, |
|
|
|
) for i in range(0, num_records) |
|
|
|
) |
|
|
|
|
|
|
|
response = Message( |
|
|
|
qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0, |
|
|
|
qd=query.qd, an=response_records, ns=(), ar=(), |
|
|
|
) |
|
|
|
return pack(response) |
|
|
|
|
|
|
|
stop_nameserver = await start_nameserver(54, get_response) |
|
|
|
self.add_async_cleanup(stop_nameserver) |
|
|
|
|
|
|
|
resolve, clear_cache = get_resolver(53) |
|
|
|
self.add_async_cleanup(clear_cache) |
|
|
|
|
|
|
|
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54)) |
|
|
|
stop = await start() |
|
|
|
self.add_async_cleanup(stop) |
|
|
|
|
|
|
|
tasks = [ |
|
|
|
asyncio.create_task(resolve('www.google.com', TYPES.A)) |
|
|
|
for _ in range(0, 100000) |
|
|
|
] |
|
|
|
|
|
|
|
for task in tasks: |
|
|
|
with self.assertRaises(DnsTimeout): |
|
|
|
await task |
|
|
|
|
|
|
|
num_records = 1 |
|
|
|
tasks = [ |
|
|
|
asyncio.create_task(resolve('www.google.com', TYPES.A)) |
|
|
|
for _ in range(0, 100000) |
|
|
|
] |
|
|
|
responses = await asyncio.gather(*tasks) |
|
|
|
for response in responses: |
|
|
|
self.assertEqual(str(response[0]), '123.100.123.0') |
|
|
|
|
|
|
|
|
|
|
|
def get_socket(port): |
|
|
|
def _get_socket(): |
|
|
|
|