Browse Source

add tests/code for no answer and out of order answers

main
John-Mark Gurney 4 years ago
parent
commit
17dbcfa47d
1 changed files with 138 additions and 7 deletions
  1. +138
    -7
      privrdns/__init__.py

+ 138
- 7
privrdns/__init__.py View File

@@ -75,16 +75,58 @@ class ServerResolver(object):


self._reader, self._writer = rdrwrr self._reader, self._writer = rdrwrr


self._qs = {}

self._replies = asyncio.create_task(self.procreplies(self._reader))

def __del__(self):
if self._replies:
self._replies.cancel()
self._replies = None

self._writer.write_eof()

def cleanup(self):
while self._qs:
i = self._qs.popitem()[1]
i.set_exception(RuntimeError(
'server connection closed, '
'no response received'))

async def procreplies(self, rdr):
while True:
# Get a response
try:
pkt = await _readpkt(self._reader)
except asyncio.streams.IncompleteReadError:
self.cleanup()
return

resp = DNSRecord.parse(pkt)

# fetch where we should send it
fut = self._qs.pop(resp.header.id, None)
if fut is None:
continue

# send the result
fut.set_result(resp.rr)

async def resolve(self, q): async def resolve(self, q):
# make DNS packet
pkt = DNSRecord(questions=[q]) pkt = DNSRecord(questions=[q])
pktbytes = pkt.pack() pktbytes = pkt.pack()


_writepkt(self._writer, pktbytes)
# Where the procreplies function will send the answer
res = asyncio.Future()

# Record it
self._qs[pkt.header.id] = res


pkt = await _readpkt(self._reader)
resp = DNSRecord.parse(pkt)
# Send query
_writepkt(self._writer, pktbytes)


return resp.rr
return await res


class DNSProc(asyncio.DatagramProtocol): class DNSProc(asyncio.DatagramProtocol):
def connection_made(self, transport): def connection_made(self, transport):
@@ -165,18 +207,44 @@ class Tests(unittest.TestCase):
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ]) bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ])


@async_test @async_test
async def xtest_query(self):
async def xtest_query(self): # pragma: no cover
# test that the query function # test that the query function
q = await query_record('com.', 'DS') q = await query_record('com.', 'DS')


@async_test @async_test
async def test_realresolver(self): # pragma: no cover
async def xtest_realresolver(self): # pragma: no cover
# This is to test against a real resolver. # This is to test against a real resolver.
host = 'gold' host = 'gold'
res = ServerResolver(await asyncio.open_connection(host, 53)) res = ServerResolver(await asyncio.open_connection(host, 53))


resquestion = DNSQuestion('example.com') resquestion = DNSQuestion('example.com')
print(repr(await res.resolve(resquestion)))

@async_test
async def test_resolvernoanswer(self):
client, server = _asyncsockpair()
res = ServerResolver(await client)

rdr, wrr = await server

# constants
resquestion = DNSQuestion('example.com')

# start the query
ans = asyncio.create_task(res.resolve(resquestion))

# Fetch the question
question = await _readpkt(rdr)
dnsrec = DNSRecord.parse(question)

# Make sure we got the correct question
self.assertEqual(dnsrec.get_q(), resquestion)

# close the connection
wrr.write_eof()

# make sure it raises an error
with self.assertRaises(RuntimeError):
await ans


@async_test @async_test
async def test_questionresolver(self): async def test_questionresolver(self):
@@ -211,3 +279,66 @@ class Tests(unittest.TestCase):
# veryify the received answer # veryify the received answer
ans = await ans ans = await ans
self.assertEqual(ans, resanswer) self.assertEqual(ans, resanswer)

@async_test
async def test_outoforderanswer(self):
client, server = _asyncsockpair()
res = ServerResolver(await client)

rdr, wrr = await server

# constants
resquestiona = DNSQuestion('example.com')
resquestionb = DNSQuestion('example.net')
resanswera = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11')
resanswerb = RR.fromZone('example.net. A 192.0.2.20\nexample.net. A 192.0.2.21')

# start the first query
ansa = asyncio.create_task(res.resolve(resquestiona))

# Fetch the first question
question = await _readpkt(rdr)
dnsreca = DNSRecord.parse(question)

# Make sure we got the correct question
self.assertEqual(dnsreca.get_q(), resquestiona)

# start the second query (before first has been answered)
ansb = asyncio.create_task(res.resolve(resquestionb))

# Fetch the second question
question = await _readpkt(rdr)
dnsrecb = DNSRecord.parse(question)

# Make sure we got the correct question
self.assertEqual(dnsrecb.get_q(), resquestionb)

# Generate the second reply
rep = dnsrecb.reply()
rep.add_answer(*resanswerb)
repbytes = rep.pack()

# Send the reply
_writepkt(wrr, repbytes)

# Generate the first reply
rep = dnsreca.reply()
rep.add_answer(*resanswera)
repbytes = rep.pack()

# Send the reply
_writepkt(wrr, repbytes)

# Send a second reply, and make sure it is ignored
_writepkt(wrr, repbytes)

# close the connection
wrr.write_eof()

# veryify the first received answer
ansa = await ansa
self.assertEqual(ansa, resanswera)

# veryify the second received answer
ansb = await ansb
self.assertEqual(ansb, resanswerb)

Loading…
Cancel
Save