@@ -5,10 +5,13 @@ A DNS proxy server that conditionally rewrites and filters A record requests | |||||
## Usage | ## Usage | ||||
By default the proxy will listen on port 53, and proxy requests to the servers in `/etc/resolve.conf`. However, by default all requests are blocked without explicit rules, so to proxy requests you must configure at least one rewrite rule. | |||||
```python | ```python | ||||
from dnsrewriteproxy import DnsProxy | from dnsrewriteproxy import DnsProxy | ||||
start = DnsProxy() | |||||
# Proxy all incoming A record requests without any rewriting | |||||
start = DnsProxy(rules=((r'(^.*$)', r'\1'),)) | |||||
# Proxy is running, accepting UDP requests on port 53 | # Proxy is running, accepting UDP requests on port 53 | ||||
stop = await start() | stop = await start() | ||||
@@ -19,6 +19,7 @@ from enum import ( | |||||
IntEnum, | IntEnum, | ||||
) | ) | ||||
import logging | import logging | ||||
import re | |||||
import socket | import socket | ||||
from aiodnsresolver import ( | from aiodnsresolver import ( | ||||
@@ -61,6 +62,7 @@ def DnsProxy( | |||||
get_resolver=get_resolver_default, get_logger=get_logger_default, | get_resolver=get_resolver_default, get_logger=get_logger_default, | ||||
get_socket=get_socket_default, | get_socket=get_socket_default, | ||||
num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000, | num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000, | ||||
rules=(), | |||||
): | ): | ||||
class ERRORS(IntEnum): | class ERRORS(IntEnum): | ||||
@@ -165,8 +167,16 @@ def DnsProxy( | |||||
name_bytes = query.qd[0].name | name_bytes = query.qd[0].name | ||||
name_str = query.qd[0].name.decode('idna') | name_str = query.qd[0].name.decode('idna') | ||||
for pattern, replace in rules: | |||||
rewritten_name_str, num_matches = re.subn(pattern, replace, name_str) | |||||
if num_matches: | |||||
break | |||||
else: | |||||
# No break was triggered, i.e. no match | |||||
return error(query, ERRORS.REFUSED) | |||||
try: | try: | ||||
ip_addresses = await resolve(name_str, TYPES.A) | |||||
ip_addresses = await resolve(rewritten_name_str, TYPES.A) | |||||
except DnsRecordDoesNotExist: | except DnsRecordDoesNotExist: | ||||
return error(query, ERRORS.NXDOMAIN) | return error(query, ERRORS.NXDOMAIN) | ||||
except DnsResponseCode as dns_response_code_error: | except DnsResponseCode as dns_response_code_error: | ||||
@@ -7,6 +7,7 @@ from aiodnsresolver import ( | |||||
TYPES, | TYPES, | ||||
Resolver, | Resolver, | ||||
IPv4AddressExpiresAt, | IPv4AddressExpiresAt, | ||||
DnsResponseCode, | |||||
) | ) | ||||
from dnsrewriteproxy import ( | from dnsrewriteproxy import ( | ||||
DnsProxy, | DnsProxy, | ||||
@@ -26,23 +27,39 @@ class TestProxy(unittest.TestCase): | |||||
self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine()) | self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine()) | ||||
@async_test | @async_test | ||||
async def test_e2e(self): | |||||
def get_socket(): | |||||
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) | |||||
sock.setblocking(False) | |||||
sock.bind(('', 3535)) | |||||
return sock | |||||
async def test_e2e_no_match_rule(self): | |||||
resolve, clear_cache = get_resolver() | |||||
self.add_async_cleanup(clear_cache) | |||||
start = DnsProxy(get_socket=get_socket) | |||||
stop = await start() | |||||
self.add_async_cleanup(stop) | |||||
async def get_nameservers(_, __): | |||||
for _ in range(0, 5): | |||||
yield (0.5, ('127.0.0.1', 3535)) | |||||
with self.assertRaises(DnsResponseCode): | |||||
await resolve('www.google.com', TYPES.A) | |||||
resolve, clear_cache = Resolver(get_nameservers=get_nameservers) | |||||
@async_test | |||||
async def test_e2e_match_all(self): | |||||
resolve, clear_cache = get_resolver() | |||||
self.add_async_cleanup(clear_cache) | self.add_async_cleanup(clear_cache) | ||||
start = DnsProxy(get_socket=get_socket) | |||||
start = DnsProxy(get_socket=get_socket, rules=((r'(^.*$)', r'\1'),)) | |||||
stop = await start() | stop = await start() | ||||
self.add_async_cleanup(stop) | self.add_async_cleanup(stop) | ||||
response = await resolve('www.google.com', TYPES.A) | response = await resolve('www.google.com', TYPES.A) | ||||
self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt)) | self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt)) | ||||
def get_socket(): | |||||
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) | |||||
sock.setblocking(False) | |||||
sock.bind(('', 3535)) | |||||
return sock | |||||
def get_resolver(): | |||||
async def get_nameservers(_, __): | |||||
for _ in range(0, 5): | |||||
yield (0.5, ('127.0.0.1', 3535)) | |||||
return Resolver(get_nameservers=get_nameservers) |