Browse Source

(refactor) Remove unnecessary code

main
Michal Charemza 5 years ago
parent
commit
b1293c46cb
No known key found for this signature in database GPG Key ID: 4BBAF0F6B73C4363
2 changed files with 43 additions and 59 deletions
  1. +12
    -32
      dnsrewriteproxy.py
  2. +31
    -27
      test.py

+ 12
- 32
dnsrewriteproxy.py View File

@@ -1,6 +1,5 @@
from asyncio import ( from asyncio import (
CancelledError, CancelledError,
Future,
Queue, Queue,
create_task, create_task,
get_running_loop, get_running_loop,
@@ -121,7 +120,10 @@ def DnsProxy(
async def downstream_worker(sock, downstream_queue): async def downstream_worker(sock, downstream_queue):
while True: while True:
response_data, addr = await downstream_queue.get() response_data, addr = await downstream_queue.get()
await sendto(loop, sock, response_data, addr)
try:
await sendto(sock, response_data, addr)
except Exception:
logger.exception('Unable to send response to %s', addr)
downstream_queue.task_done() downstream_queue.task_done()


async def get_response_data(resolve, request_data): async def get_response_data(resolve, request_data):
@@ -209,37 +211,15 @@ def error(query, rcode):
) )




async def sendto(loop, sock, data, addr):
async def sendto(sock, data, addr):
# In our cases, the UDP responses will always be 512 bytes or less. # In our cases, the UDP responses will always be 512 bytes or less.
# Even if sendto sent some of the data, there is no way for the other # Even if sendto sent some of the data, there is no way for the other
# end to reconstruct their order, so we don't include any logic to send # end to reconstruct their order, so we don't include any logic to send
# the rest of the data. Since it's UDP, the client already has to have # the rest of the data. Since it's UDP, the client already has to have
# retry logic

try:
return sock.sendto(data, addr)
except BlockingIOError:
pass

def writer():
try:
num_bytes = sock.sendto(data, addr)
except BlockingIOError:
pass
except BaseException as exception:
loop.remove_writer(fileno)
if not result.done():
result.set_exception(exception)
else:
loop.remove_writer(fileno)
if not result.done():
result.set_result(num_bytes)

fileno = sock.fileno()
result = Future()
loop.add_writer(fileno, writer)

try:
return await result
finally:
loop.remove_writer(fileno)
# retry logic.
#
# Potentially also, this can raise a BlockingIOError, but even trying
# to force high numbers of messages with a small socket buffer, this has
# never been observed. As above, the client must have retry logic, so we
# leave it to the client to deal with this.
return sock.sendto(data, addr)

+ 31
- 27
test.py View File

@@ -103,32 +103,10 @@ class TestProxy(unittest.TestCase):
self.assertEqual(cm.exception.args[0], 5) self.assertEqual(cm.exception.args[0], 5)


@async_test @async_test
async def test_many_of_responses_with_small_socket_buffer(self):
async def test_many_responses_with_small_socket_buffer(self):
resolve, clear_cache = get_resolver(53) resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache) self.add_async_cleanup(clear_cache)


def get_small_socket():
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2000)
sock.setblocking(False)
sock.bind(('', 53))
return sock

def get_fixed_resolver():
async def get_host(_, fqdn, qtype):
hosts = {
b'www.google.com': {
TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
},
}
try:
return hosts[fqdn.lower()][qtype]
except KeyError:
print('NONE!')
return None

return Resolver(get_host=get_host)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket, start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
get_resolver=get_fixed_resolver) get_resolver=get_fixed_resolver)
stop = await start() stop = await start()
@@ -139,11 +117,13 @@ class TestProxy(unittest.TestCase):
for _ in range(0, 100000) for _ in range(0, 100000)
] ]


responses = []
for task in tasks:
responses.append(await task)
responses = await asyncio.gather(*tasks)

for response in responses:
self.assertEqual(str(response[0]), '1.2.3.4')


self.assertEqual(str(responses[0][0]), '1.2.3.4')
bing_responses = await resolve('www.bing.com', TYPES.A)
self.assertTrue(isinstance(bing_responses[0], IPv4AddressExpiresAt))




def get_socket(port): def get_socket(port):
@@ -155,9 +135,33 @@ def get_socket(port):
return _get_socket return _get_socket




def get_small_socket():
# For linux, the minimum buffer size is 1024
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
sock.setblocking(False)
sock.bind(('', 53))
return sock


def get_resolver(port, timeout=0.5): def get_resolver(port, timeout=0.5):
async def get_nameservers(_, __): async def get_nameservers(_, __):
for _ in range(0, 5): for _ in range(0, 5):
yield (timeout, ('127.0.0.1', port)) yield (timeout, ('127.0.0.1', port))


return Resolver(get_nameservers=get_nameservers) return Resolver(get_nameservers=get_nameservers)


def get_fixed_resolver():
async def get_host(_, fqdn, qtype):
hosts = {
b'www.google.com': {
TYPES.A: IPv4AddressExpiresAt('1.2.3.4', expires_at=0),
},
}
try:
return hosts[fqdn.lower()][qtype]
except KeyError:
return None

return Resolver(get_host=get_host)

Loading…
Cancel
Save