|
|
@@ -31,20 +31,70 @@ __license__ = '2-clause BSD license' |
|
|
|
__version__ = '0.1.0.dev' |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import dns |
|
|
|
import socket |
|
|
|
import unittest |
|
|
|
|
|
|
|
from dnslib import DNSRecord, RR, QTYPE, A, digparser |
|
|
|
from dnslib import DNSRecord, DNSQuestion, RR, QTYPE, A, digparser |
|
|
|
from ntunnel import async_test, parsesockstr |
|
|
|
|
|
|
|
class DNSCache(object): |
|
|
|
def __init__(self): |
|
|
|
self._data = {} |
|
|
|
|
|
|
|
# preload some values: |
|
|
|
self._data.update({ |
|
|
|
('.', 'DS'): [ ('.', 'IN', 'DS', 20326, 8, 2, |
|
|
|
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D'))], |
|
|
|
}) |
|
|
|
|
|
|
|
def get(self, name, rtype): |
|
|
|
return self._data[(name, rtype)] |
|
|
|
|
|
|
|
class Resolver(object): |
|
|
|
'''Resolve DNS names to records. |
|
|
|
|
|
|
|
Either serve it from the cache, or do a query. If the query |
|
|
|
can be DNSSEC validated, do it via the privacy querier if not, |
|
|
|
via the trusted one.''' |
|
|
|
|
|
|
|
def _writepkt(wrr, bts): |
|
|
|
wrr.write(len(bts).to_bytes(2, 'big')) |
|
|
|
wrr.write(bts) |
|
|
|
|
|
|
|
async def _readpkt(rdr): |
|
|
|
nbytes = int.from_bytes(await rdr.readexactly(2), 'big') |
|
|
|
return await rdr.readexactly(nbytes) |
|
|
|
|
|
|
|
class ServerResolver(object): |
|
|
|
'''Ask a specific server a question, and return the result.''' |
|
|
|
|
|
|
|
def __init__(self, rdrwrr): |
|
|
|
'''Pass in the reader and writer pair that is connected |
|
|
|
to the server.''' |
|
|
|
|
|
|
|
self._reader, self._writer = rdrwrr |
|
|
|
|
|
|
|
async def resolve(self, q): |
|
|
|
pkt = DNSRecord(questions=[q]) |
|
|
|
pktbytes = pkt.pack() |
|
|
|
|
|
|
|
_writepkt(self._writer, pktbytes) |
|
|
|
|
|
|
|
pkt = await _readpkt(self._reader) |
|
|
|
resp = DNSRecord.parse(pkt) |
|
|
|
|
|
|
|
return resp.rr |
|
|
|
|
|
|
|
class DNSProc(asyncio.DatagramProtocol): |
|
|
|
def connection_made(self, transport): |
|
|
|
print('cm') |
|
|
|
#print('cm') |
|
|
|
self.transport = transport |
|
|
|
|
|
|
|
def datagram_received(self, data, addr): |
|
|
|
print('dr') |
|
|
|
#print('dr') |
|
|
|
pkt = DNSRecord.parse(data) |
|
|
|
print(repr((pkt, addr))) |
|
|
|
#print(repr((pkt, addr))) |
|
|
|
|
|
|
|
d = pkt.reply() |
|
|
|
d.add_answer(RR("xxx.abc.com",QTYPE.A,rdata=A("1.2.3.4"))) |
|
|
@@ -56,13 +106,23 @@ async def dnsprocessor(sockstr): |
|
|
|
|
|
|
|
if proto == 'udp': |
|
|
|
loop = asyncio.get_event_loop() |
|
|
|
print('pre-cde') |
|
|
|
#print('pre-cde') |
|
|
|
trans, protocol = await loop.create_datagram_endpoint(DNSProc, |
|
|
|
local_addr=(args.get('host', '127.0.0.1'), args['port'])) |
|
|
|
print('post-cde', repr((trans, protocol)), trans is protocol.transport) |
|
|
|
#print('post-cde', repr((trans, protocol)), trans is protocol.transport) |
|
|
|
else: |
|
|
|
raise ValueError('unknown protocol: %s' % repr(proto)) |
|
|
|
|
|
|
|
def _asyncsockpair(): |
|
|
|
'''Create a pair of sockets that are bound to each other. |
|
|
|
The function will return a tuple of two coroutine's, that |
|
|
|
each, when await'ed upon, will return the reader/writer pair.''' |
|
|
|
|
|
|
|
socka, sockb = socket.socketpair() |
|
|
|
|
|
|
|
return (asyncio.open_connection(sock=socka), |
|
|
|
asyncio.open_connection(sock=sockb)) |
|
|
|
|
|
|
|
class Tests(unittest.TestCase): |
|
|
|
@async_test |
|
|
|
async def test_processdnsfailures(self): |
|
|
@@ -70,7 +130,7 @@ class Tests(unittest.TestCase): |
|
|
|
with self.assertRaises(ValueError): |
|
|
|
dnsproc = await dnsprocessor( |
|
|
|
'tcp:host=127.0.0.1,port=%d' % port) |
|
|
|
print('post-dns') |
|
|
|
#print('post-dns') |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_processdns(self): |
|
|
@@ -92,4 +152,62 @@ class Tests(unittest.TestCase): |
|
|
|
|
|
|
|
rep = list(digparser.DigParser(stdout)) |
|
|
|
self.assertEqual(len(rep), 1) |
|
|
|
print('x', repr(list(rep))) |
|
|
|
#print('x', repr(list(rep))) |
|
|
|
|
|
|
|
def test_cache(self): |
|
|
|
# test that the cache |
|
|
|
cache = DNSCache() |
|
|
|
|
|
|
|
# has the root trust anchors in it |
|
|
|
dsrs = cache.get('.', 'DS') |
|
|
|
|
|
|
|
self.assertEqual(dsrs, [ ('.', 'IN', 'DS', 20326, 8, 2, |
|
|
|
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ]) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def xtest_query(self): |
|
|
|
# test that the query function |
|
|
|
q = await query_record('com.', 'DS') |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_realresolver(self): # pragma: no cover |
|
|
|
# This is to test against a real resolver. |
|
|
|
host = 'gold' |
|
|
|
res = ServerResolver(await asyncio.open_connection(host, 53)) |
|
|
|
|
|
|
|
resquestion = DNSQuestion('example.com') |
|
|
|
print(repr(await res.resolve(resquestion))) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_questionresolver(self): |
|
|
|
client, server = _asyncsockpair() |
|
|
|
res = ServerResolver(await client) |
|
|
|
|
|
|
|
rdr, wrr = await server |
|
|
|
|
|
|
|
# constants |
|
|
|
resquestion = DNSQuestion('example.com') |
|
|
|
resanswer = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11') |
|
|
|
|
|
|
|
# start the query |
|
|
|
ans = asyncio.create_task(res.resolve(resquestion)) |
|
|
|
|
|
|
|
# Fetch the question |
|
|
|
question = await _readpkt(rdr) |
|
|
|
dnsrec = DNSRecord.parse(question) |
|
|
|
|
|
|
|
# Make sure we got the correct question |
|
|
|
self.assertEqual(dnsrec.get_q(), resquestion) |
|
|
|
|
|
|
|
# Generate the reply |
|
|
|
rep = dnsrec.reply() |
|
|
|
rep.add_answer(*resanswer) |
|
|
|
repbytes = rep.pack() |
|
|
|
|
|
|
|
# Send the reply |
|
|
|
_writepkt(wrr, repbytes) |
|
|
|
wrr.write_eof() |
|
|
|
|
|
|
|
# veryify the received answer |
|
|
|
ans = await ans |
|
|
|
self.assertEqual(ans, resanswer) |