From b1293c46cb642428feed7b108970bf7f7c80442f Mon Sep 17 00:00:00 2001 From: Michal Charemza Date: Fri, 17 Jan 2020 09:45:03 +0000 Subject: [PATCH] (refactor) Remove unnecessary code --- dnsrewriteproxy.py | 44 ++++++++++------------------------- test.py | 58 +++++++++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 59 deletions(-) diff --git a/dnsrewriteproxy.py b/dnsrewriteproxy.py index 12f4224..e652f3b 100644 --- a/dnsrewriteproxy.py +++ b/dnsrewriteproxy.py @@ -1,6 +1,5 @@ from asyncio import ( CancelledError, - Future, Queue, create_task, get_running_loop, @@ -121,7 +120,10 @@ def DnsProxy( async def downstream_worker(sock, downstream_queue): while True: response_data, addr = await downstream_queue.get() - await sendto(loop, sock, response_data, addr) + try: + await sendto(sock, response_data, addr) + except Exception: + logger.exception('Unable to send response to %s', addr) downstream_queue.task_done() async def get_response_data(resolve, request_data): @@ -209,37 +211,15 @@ def error(query, rcode): ) -async def sendto(loop, sock, data, addr): +async def sendto(sock, data, addr): # In our cases, the UDP responses will always be 512 bytes or less. # Even if sendto sent some of the data, there is no way for the other # end to reconstruct their order, so we don't include any logic to send # the rest of the data. Since it's UDP, the client already has to have - # retry logic - - try: - return sock.sendto(data, addr) - except BlockingIOError: - pass - - def writer(): - try: - num_bytes = sock.sendto(data, addr) - except BlockingIOError: - pass - except BaseException as exception: - loop.remove_writer(fileno) - if not result.done(): - result.set_exception(exception) - else: - loop.remove_writer(fileno) - if not result.done(): - result.set_result(num_bytes) - - fileno = sock.fileno() - result = Future() - loop.add_writer(fileno, writer) - - try: - return await result - finally: - loop.remove_writer(fileno) + # retry logic. + # + # Potentially also, this can raise a BlockingIOError, but even trying + # to force high numbers of messages with a small socket buffer, this has + # never been observed. As above, the client must have retry logic, so we + # leave it to the client to deal with this. + return sock.sendto(data, addr) diff --git a/test.py b/test.py index a7b395d..d331cc4 100644 --- a/test.py +++ b/test.py @@ -103,32 +103,10 @@ class TestProxy(unittest.TestCase): self.assertEqual(cm.exception.args[0], 5) @async_test - async def test_many_of_responses_with_small_socket_buffer(self): + async def test_many_responses_with_small_socket_buffer(self): resolve, clear_cache = get_resolver(53) self.add_async_cleanup(clear_cache) - def get_small_socket(): - sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2000) - sock.setblocking(False) - sock.bind(('', 53)) - return sock - - def get_fixed_resolver(): - async def get_host(_, fqdn, qtype): - hosts = { - b'www.google.com': { - TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0), - }, - } - try: - return hosts[fqdn.lower()][qtype] - except KeyError: - print('NONE!') - return None - - return Resolver(get_host=get_host) - start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket, get_resolver=get_fixed_resolver) stop = await start() @@ -139,11 +117,13 @@ class TestProxy(unittest.TestCase): for _ in range(0, 100000) ] - responses = [] - for task in tasks: - responses.append(await task) + responses = await asyncio.gather(*tasks) + + for response in responses: + self.assertEqual(str(response[0]), '1.2.3.4') - self.assertEqual(str(responses[0][0]), '1.2.3.4') + bing_responses = await resolve('www.bing.com', TYPES.A) + self.assertTrue(isinstance(bing_responses[0], IPv4AddressExpiresAt)) def get_socket(port): @@ -155,9 +135,33 @@ def get_socket(port): return _get_socket +def get_small_socket(): + # For linux, the minimum buffer size is 1024 + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setblocking(False) + sock.bind(('', 53)) + return sock + + def get_resolver(port, timeout=0.5): async def get_nameservers(_, __): for _ in range(0, 5): yield (timeout, ('127.0.0.1', port)) return Resolver(get_nameservers=get_nameservers) + + +def get_fixed_resolver(): + async def get_host(_, fqdn, qtype): + hosts = { + b'www.google.com': { + TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0), + }, + } + try: + return hosts[fqdn.lower()][qtype] + except KeyError: + return None + + return Resolver(get_host=get_host)