Browse Source

(feat) Expose server task

main
Michal Charemza 5 years ago
parent
commit
fa3725c34a
No known key found for this signature in database GPG Key ID: 4BBAF0F6B73C4363
3 changed files with 152 additions and 48 deletions
  1. +45
    -5
      README.md
  2. +4
    -9
      dnsrewriteproxy.py
  3. +103
    -34
      test.py

+ 45
- 5
README.md View File

@@ -15,11 +15,8 @@ from dnsrewriteproxy import DnsProxy
# Proxy all incoming A record requests without any rewriting
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))

# Proxy is running, accepting UDP requests on port 53
stop = await start()

# Stopped
await stop()
# Run proxy, accepting UDP requests on port 53
await start()
```

The `rules` parameter must be an iterable [e.g. a list or a tuple] of tuples, where each tuple is regex pattern/replacement pair, passed to [re.subn](https://docs.python.org/3/library/re.html#re.subn) under the hood. On each incoming DNS request from downstream for a domain
@@ -46,3 +43,46 @@ start = DnsProxy(rules=(
(r'(^.*$)', r'\1'),
))
```


## Server lifecycle

In the above example `await start()` completes just after the server has started listening. The coroutine `start` returns the underlying _task_ to give control over the server lifecycle. A task can be seen as an "asyncio thread"; this is exposed to allow the server to sit in a larger asyncio Python program that may have a specific startup/shutdown procedure.


### Run forever

You can run the server forever [or until it hits some non-recoverable error] by awaiting this task.

```python
from dnsrewriteproxy import DnsProxy

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
server_task = await start()

# Waiting here until the server is stopped
await server_task
```


### Stopping the server

To stop the server, you can `cancel` the returned task.

```python
from dnsrewriteproxy import DnsProxy

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
proxy_task = await start()

# ... Receive requests

# Initiate stopping: new requests will not be processed...
proxy_task.cancel()

try:
# ... and we wait until previously received requests have been processed
await proxy_task
except asyncio.CancelledError:
pass
```

+ 4
- 9
dnsrewriteproxy.py View File

@@ -59,7 +59,7 @@ def DnsProxy(
# them in a queue that is then fetched from and processed by the proxy
# workers

async def server_worker(sock, resolve):
async def server_worker(sock, resolve, stop):
upstream_queue = Queue(maxsize=num_workers)

# We have multiple upstream workers to be able to send multiple
@@ -84,6 +84,8 @@ def DnsProxy(
except CancelledError:
pass

await stop()

async def upstream_worker(sock, resolve, upstream_queue):
while True:
request_data, addr = await upstream_queue.get()
@@ -158,19 +160,12 @@ def DnsProxy(
# /etc/hosts or /etc/resolve.conf, and can raise an exception if
# something goes wrong with that
resolve, clear_cache = get_resolver()
server_worker_task = create_task(server_worker(sock, resolve))

async def stop():
server_worker_task.cancel()
try:
await server_worker_task
except CancelledError:
pass

sock.close()
await clear_cache()

return stop
return create_task(server_worker(sock, resolve, stop))

return start



+ 103
- 34
test.py View File

@@ -35,16 +35,16 @@ def async_test(func):


class TestProxy(unittest.TestCase):
def add_async_cleanup(self, coroutine):
self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
def add_async_cleanup(self, coroutine, *args):
self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine(*args))

@async_test
async def test_e2e_no_match_rule(self):
resolve, clear_cache = get_resolver(3535)
self.add_async_cleanup(clear_cache)
start = DnsProxy(get_socket=get_socket(3535))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -56,8 +56,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(3535)
self.add_async_cleanup(clear_cache)
start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)

@@ -68,8 +68,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)

@@ -80,8 +80,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

with self.assertRaises(DnsRecordDoesNotExist):
await resolve('doesnotexist.charemza.name', TYPES.A)
@@ -91,8 +91,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy(rules=((r'^doesnotexist\.charemza\.name$', r'www.google.com'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('doesnotexist.charemza.name', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -103,8 +103,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -116,8 +116,8 @@ class TestProxy(unittest.TestCase):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)
start = DnsProxy()
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

with self.assertRaises(DnsResponseCode) as cm:
await resolve('doesnotexist.charemza.name', TYPES.A)
@@ -131,8 +131,8 @@ class TestProxy(unittest.TestCase):

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket,
get_resolver=get_fixed_resolver)
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -150,8 +150,8 @@ class TestProxy(unittest.TestCase):
@async_test
async def test_many_responses_with_small_socket_buffer_onward_query(self):
start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_socket=get_small_socket)
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

async def resolve(domain):
resolve, clear_cache = get_resolver(53)
@@ -178,8 +178,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -213,8 +213,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

with self.assertRaises(DnsResponseCode) as cm:
await resolve('www.google.com', TYPES.A)
@@ -233,8 +233,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -258,8 +258,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -283,8 +283,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -310,8 +310,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)
@@ -387,8 +387,8 @@ class TestProxy(unittest.TestCase):
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
stop = await start()
self.add_async_cleanup(stop)
server_task = await start()
self.add_async_cleanup(await_cancel, server_task)

tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
@@ -408,6 +408,67 @@ class TestProxy(unittest.TestCase):
for response in responses:
self.assertEqual(str(response[0]), '123.100.123.0')

@async_test
async def test_server_response_after_cancel_returned_to_client(self):
received_request = asyncio.Event()
continue_request = asyncio.Event()

async def get_response(query_data):
query = parse(query_data)
response_record = ResourceRecord(
name=query.qd[0].name,
qtype=TYPES.A,
qclass=1,
ttl=0,
rdata=ipaddress.IPv4Address('123.100.123.1').packed,
)

response = Message(
qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
qd=query.qd, an=(response_record,), ns=(), ar=(),
)
received_request.set()
await continue_request.wait()
return pack(response)

stop_nameserver = await start_nameserver(54, get_response)
self.add_async_cleanup(stop_nameserver)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),), get_resolver=lambda: get_resolver(54))
server_task = await start()

async def resolve(domain):
resolve, clear_cache = get_resolver(53)
result = await resolve(domain, TYPES.A)
await clear_cache()
return result

# Start a set of requests
tasks = [
asyncio.create_task(resolve('www.google.com'))
for _ in range(0, 1000)
]
await received_request.wait()

# Cancel the server...
server_task.cancel()

# ... start a new request
after_cancel_task = asyncio.create_task(resolve('www.bing.com'))

# ... wait 5 seconds
await asyncio.sleep(0.5)

# ... then finally the upstream server continues with the processing
# of the requests received before cancellation
continue_request.set()
for response in await asyncio.gather(*tasks):
self.assertEqual(str(response[0]), '123.100.123.1')

# ... but the request started after cancellation times out
with self.assertRaises(DnsTimeout):
await after_cancel_task


def get_socket(port):
def _get_socket():
@@ -481,3 +542,11 @@ async def start_nameserver(port, get_response):
sock.close()

return stop


async def await_cancel(task):
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

Loading…
Cancel
Save