From 17dbcfa47d0d4e4171b328128067ace4c8436115 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Wed, 11 Dec 2019 01:37:52 -0800 Subject: [PATCH] add tests/code for no answer and out of order answers --- privrdns/__init__.py | 145 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/privrdns/__init__.py b/privrdns/__init__.py index 051e453..01193b8 100644 --- a/privrdns/__init__.py +++ b/privrdns/__init__.py @@ -75,16 +75,58 @@ class ServerResolver(object): self._reader, self._writer = rdrwrr + self._qs = {} + + self._replies = asyncio.create_task(self.procreplies(self._reader)) + + def __del__(self): + if self._replies: + self._replies.cancel() + self._replies = None + + self._writer.write_eof() + + def cleanup(self): + while self._qs: + i = self._qs.popitem()[1] + i.set_exception(RuntimeError( + 'server connection closed, ' + 'no response received')) + + async def procreplies(self, rdr): + while True: + # Get a response + try: + pkt = await _readpkt(self._reader) + except asyncio.streams.IncompleteReadError: + self.cleanup() + return + + resp = DNSRecord.parse(pkt) + + # fetch where we should send it + fut = self._qs.pop(resp.header.id, None) + if fut is None: + continue + + # send the result + fut.set_result(resp.rr) + async def resolve(self, q): + # make DNS packet pkt = DNSRecord(questions=[q]) pktbytes = pkt.pack() - _writepkt(self._writer, pktbytes) + # Where the procreplies function will send the answer + res = asyncio.Future() + + # Record it + self._qs[pkt.header.id] = res - pkt = await _readpkt(self._reader) - resp = DNSRecord.parse(pkt) + # Send query + _writepkt(self._writer, pktbytes) - return resp.rr + return await res class DNSProc(asyncio.DatagramProtocol): def connection_made(self, transport): @@ -165,18 +207,44 @@ class Tests(unittest.TestCase): bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ]) @async_test - async def xtest_query(self): + async def xtest_query(self): # pragma: no cover # test that the query function q = await query_record('com.', 'DS') @async_test - async def test_realresolver(self): # pragma: no cover + async def xtest_realresolver(self): # pragma: no cover # This is to test against a real resolver. host = 'gold' res = ServerResolver(await asyncio.open_connection(host, 53)) resquestion = DNSQuestion('example.com') - print(repr(await res.resolve(resquestion))) + + @async_test + async def test_resolvernoanswer(self): + client, server = _asyncsockpair() + res = ServerResolver(await client) + + rdr, wrr = await server + + # constants + resquestion = DNSQuestion('example.com') + + # start the query + ans = asyncio.create_task(res.resolve(resquestion)) + + # Fetch the question + question = await _readpkt(rdr) + dnsrec = DNSRecord.parse(question) + + # Make sure we got the correct question + self.assertEqual(dnsrec.get_q(), resquestion) + + # close the connection + wrr.write_eof() + + # make sure it raises an error + with self.assertRaises(RuntimeError): + await ans @async_test async def test_questionresolver(self): @@ -211,3 +279,66 @@ class Tests(unittest.TestCase): # veryify the received answer ans = await ans self.assertEqual(ans, resanswer) + + @async_test + async def test_outoforderanswer(self): + client, server = _asyncsockpair() + res = ServerResolver(await client) + + rdr, wrr = await server + + # constants + resquestiona = DNSQuestion('example.com') + resquestionb = DNSQuestion('example.net') + resanswera = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11') + resanswerb = RR.fromZone('example.net. A 192.0.2.20\nexample.net. A 192.0.2.21') + + # start the first query + ansa = asyncio.create_task(res.resolve(resquestiona)) + + # Fetch the first question + question = await _readpkt(rdr) + dnsreca = DNSRecord.parse(question) + + # Make sure we got the correct question + self.assertEqual(dnsreca.get_q(), resquestiona) + + # start the second query (before first has been answered) + ansb = asyncio.create_task(res.resolve(resquestionb)) + + # Fetch the second question + question = await _readpkt(rdr) + dnsrecb = DNSRecord.parse(question) + + # Make sure we got the correct question + self.assertEqual(dnsrecb.get_q(), resquestionb) + + # Generate the second reply + rep = dnsrecb.reply() + rep.add_answer(*resanswerb) + repbytes = rep.pack() + + # Send the reply + _writepkt(wrr, repbytes) + + # Generate the first reply + rep = dnsreca.reply() + rep.add_answer(*resanswera) + repbytes = rep.pack() + + # Send the reply + _writepkt(wrr, repbytes) + + # Send a second reply, and make sure it is ignored + _writepkt(wrr, repbytes) + + # close the connection + wrr.write_eof() + + # veryify the first received answer + ansa = await ansa + self.assertEqual(ansa, resanswera) + + # veryify the second received answer + ansb = await ansb + self.assertEqual(ansb, resanswerb)