diff --git a/test.py b/test.py index f4bcf1e..5b7ab88 100644 --- a/test.py +++ b/test.py @@ -164,7 +164,7 @@ class TestProxy(unittest.TestCase): self.assertEqual(cm.exception.args[0], 5) @async_test - async def test_sending_bad_messages_not_affect_later_queries(self): + async def test_sending_bad_messages_not_affect_later_queries_a(self): resolve, clear_cache = get_resolver(53) self.add_async_cleanup(clear_cache) @@ -188,6 +188,31 @@ class TestProxy(unittest.TestCase): for response in responses: self.assertEqual(type(response[0]), IPv4AddressExpiresAt) + @async_test + async def test_sending_bad_messages_not_affect_later_queries_b(self): + resolve, clear_cache = get_resolver(53) + self.add_async_cleanup(clear_cache) + + start = DnsProxy(rules=((r'(^.*$)', r'\1'),)) + stop = await start() + self.add_async_cleanup(stop) + + response = await resolve('www.google.com', TYPES.A) + self.assertEqual(type(response[0]), IPv4AddressExpiresAt) + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + for _ in range(0, 100000): + sock.sendto(b'not-a-valid-message', ('127.0.0.1', 53)) + sock.close() + + tasks = [ + asyncio.create_task(resolve('www.google.com', TYPES.A)) + for _ in range(0, 100000) + ] + responses = await asyncio.gather(*tasks) + for response in responses: + self.assertEqual(type(response[0]), IPv4AddressExpiresAt) + def get_socket(port): def _get_socket():