Browse Source

add some basic code to do resolution..

main
John-Mark Gurney 5 years ago
parent
commit
3c62776db1
3 changed files with 144 additions and 9 deletions
  1. +16
    -0
      NOTES.md
  2. +126
    -8
      privrdns/__init__.py
  3. +2
    -1
      setup.py

+ 16
- 0
NOTES.md View File

@@ -1 +1,17 @@
Good DNS lib, but doesn't do DNSSEC validation:
https://github.com/paulc/dnslib https://github.com/paulc/dnslib

This looks like it supports DNSSEC:
https://github.com/rthalley/dnspython.git


Resovler class
always recursive
check cache, serve from there

ServerResolver
talks to a specific server, either tcp/udp.
Keeps connection open, for repeated queries
raises exception when connection closes,
callee deals w/ starting new session
keeps stats on performance so caller can choose best NS

+ 126
- 8
privrdns/__init__.py View File

@@ -31,20 +31,70 @@ __license__ = '2-clause BSD license'
__version__ = '0.1.0.dev' __version__ = '0.1.0.dev'


import asyncio import asyncio
import dns
import socket
import unittest import unittest


from dnslib import DNSRecord, RR, QTYPE, A, digparser
from dnslib import DNSRecord, DNSQuestion, RR, QTYPE, A, digparser
from ntunnel import async_test, parsesockstr from ntunnel import async_test, parsesockstr


class DNSCache(object):
def __init__(self):
self._data = {}

# preload some values:
self._data.update({
('.', 'DS'): [ ('.', 'IN', 'DS', 20326, 8, 2,
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D'))],
})

def get(self, name, rtype):
return self._data[(name, rtype)]

class Resolver(object):
'''Resolve DNS names to records.

Either serve it from the cache, or do a query. If the query
can be DNSSEC validated, do it via the privacy querier if not,
via the trusted one.'''

def _writepkt(wrr, bts):
wrr.write(len(bts).to_bytes(2, 'big'))
wrr.write(bts)

async def _readpkt(rdr):
nbytes = int.from_bytes(await rdr.readexactly(2), 'big')
return await rdr.readexactly(nbytes)

class ServerResolver(object):
'''Ask a specific server a question, and return the result.'''

def __init__(self, rdrwrr):
'''Pass in the reader and writer pair that is connected
to the server.'''

self._reader, self._writer = rdrwrr

async def resolve(self, q):
pkt = DNSRecord(questions=[q])
pktbytes = pkt.pack()

_writepkt(self._writer, pktbytes)

pkt = await _readpkt(self._reader)
resp = DNSRecord.parse(pkt)

return resp.rr

class DNSProc(asyncio.DatagramProtocol): class DNSProc(asyncio.DatagramProtocol):
def connection_made(self, transport): def connection_made(self, transport):
print('cm')
#print('cm')
self.transport = transport self.transport = transport


def datagram_received(self, data, addr): def datagram_received(self, data, addr):
print('dr')
#print('dr')
pkt = DNSRecord.parse(data) pkt = DNSRecord.parse(data)
print(repr((pkt, addr)))
#print(repr((pkt, addr)))


d = pkt.reply() d = pkt.reply()
d.add_answer(RR("xxx.abc.com",QTYPE.A,rdata=A("1.2.3.4"))) d.add_answer(RR("xxx.abc.com",QTYPE.A,rdata=A("1.2.3.4")))
@@ -56,13 +106,23 @@ async def dnsprocessor(sockstr):


if proto == 'udp': if proto == 'udp':
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
print('pre-cde')
#print('pre-cde')
trans, protocol = await loop.create_datagram_endpoint(DNSProc, trans, protocol = await loop.create_datagram_endpoint(DNSProc,
local_addr=(args.get('host', '127.0.0.1'), args['port'])) local_addr=(args.get('host', '127.0.0.1'), args['port']))
print('post-cde', repr((trans, protocol)), trans is protocol.transport)
#print('post-cde', repr((trans, protocol)), trans is protocol.transport)
else: else:
raise ValueError('unknown protocol: %s' % repr(proto)) raise ValueError('unknown protocol: %s' % repr(proto))


def _asyncsockpair():
'''Create a pair of sockets that are bound to each other.
The function will return a tuple of two coroutine's, that
each, when await'ed upon, will return the reader/writer pair.'''

socka, sockb = socket.socketpair()

return (asyncio.open_connection(sock=socka),
asyncio.open_connection(sock=sockb))

class Tests(unittest.TestCase): class Tests(unittest.TestCase):
@async_test @async_test
async def test_processdnsfailures(self): async def test_processdnsfailures(self):
@@ -70,7 +130,7 @@ class Tests(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
dnsproc = await dnsprocessor( dnsproc = await dnsprocessor(
'tcp:host=127.0.0.1,port=%d' % port) 'tcp:host=127.0.0.1,port=%d' % port)
print('post-dns')
#print('post-dns')


@async_test @async_test
async def test_processdns(self): async def test_processdns(self):
@@ -92,4 +152,62 @@ class Tests(unittest.TestCase):


rep = list(digparser.DigParser(stdout)) rep = list(digparser.DigParser(stdout))
self.assertEqual(len(rep), 1) self.assertEqual(len(rep), 1)
print('x', repr(list(rep)))
#print('x', repr(list(rep)))

def test_cache(self):
# test that the cache
cache = DNSCache()

# has the root trust anchors in it
dsrs = cache.get('.', 'DS')

self.assertEqual(dsrs, [ ('.', 'IN', 'DS', 20326, 8, 2,
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ])

@async_test
async def xtest_query(self):
# test that the query function
q = await query_record('com.', 'DS')

@async_test
async def test_realresolver(self): # pragma: no cover
# This is to test against a real resolver.
host = 'gold'
res = ServerResolver(await asyncio.open_connection(host, 53))

resquestion = DNSQuestion('example.com')
print(repr(await res.resolve(resquestion)))

@async_test
async def test_questionresolver(self):
client, server = _asyncsockpair()
res = ServerResolver(await client)

rdr, wrr = await server

# constants
resquestion = DNSQuestion('example.com')
resanswer = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11')

# start the query
ans = asyncio.create_task(res.resolve(resquestion))

# Fetch the question
question = await _readpkt(rdr)
dnsrec = DNSRecord.parse(question)

# Make sure we got the correct question
self.assertEqual(dnsrec.get_q(), resquestion)

# Generate the reply
rep = dnsrec.reply()
rep.add_answer(*resanswer)
repbytes = rep.pack()

# Send the reply
_writepkt(wrr, repbytes)
wrr.write_eof()

# veryify the received answer
ans = await ans
self.assertEqual(ans, resanswer)

+ 2
- 1
setup.py View File

@@ -34,7 +34,8 @@ setup(name='privrdns',
python_requires='~=3.7', python_requires='~=3.7',
install_requires=[ install_requires=[
'ntunnel @ git+https://www.funkthat.com/gitea/jmg/ntunnel.git@c203547c28e935d11855601ce4e4f31db4e9065d', 'ntunnel @ git+https://www.funkthat.com/gitea/jmg/ntunnel.git@c203547c28e935d11855601ce4e4f31db4e9065d',
'dnslib @ git+https://github.com/paulc/dnslib.git'
'dnspython @ git+https://github.com/rthalley/dnspython.git',
'dnslib @ git+https://github.com/paulc/dnslib.git',
], ],
extras_require = { extras_require = {
'dev': [ 'coverage' ], 'dev': [ 'coverage' ],


Loading…
Cancel
Save