| @@ -75,16 +75,58 @@ class ServerResolver(object): | |||
| self._reader, self._writer = rdrwrr | |||
| self._qs = {} | |||
| self._replies = asyncio.create_task(self.procreplies(self._reader)) | |||
| def __del__(self): | |||
| if self._replies: | |||
| self._replies.cancel() | |||
| self._replies = None | |||
| self._writer.write_eof() | |||
| def cleanup(self): | |||
| while self._qs: | |||
| i = self._qs.popitem()[1] | |||
| i.set_exception(RuntimeError( | |||
| 'server connection closed, ' | |||
| 'no response received')) | |||
| async def procreplies(self, rdr): | |||
| while True: | |||
| # Get a response | |||
| try: | |||
| pkt = await _readpkt(self._reader) | |||
| except asyncio.streams.IncompleteReadError: | |||
| self.cleanup() | |||
| return | |||
| resp = DNSRecord.parse(pkt) | |||
| # fetch where we should send it | |||
| fut = self._qs.pop(resp.header.id, None) | |||
| if fut is None: | |||
| continue | |||
| # send the result | |||
| fut.set_result(resp.rr) | |||
| async def resolve(self, q): | |||
| # make DNS packet | |||
| pkt = DNSRecord(questions=[q]) | |||
| pktbytes = pkt.pack() | |||
| _writepkt(self._writer, pktbytes) | |||
| # Where the procreplies function will send the answer | |||
| res = asyncio.Future() | |||
| # Record it | |||
| self._qs[pkt.header.id] = res | |||
| pkt = await _readpkt(self._reader) | |||
| resp = DNSRecord.parse(pkt) | |||
| # Send query | |||
| _writepkt(self._writer, pktbytes) | |||
| return resp.rr | |||
| return await res | |||
| class DNSProc(asyncio.DatagramProtocol): | |||
| def connection_made(self, transport): | |||
| @@ -165,18 +207,44 @@ class Tests(unittest.TestCase): | |||
| bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ]) | |||
| @async_test | |||
| async def xtest_query(self): | |||
| async def xtest_query(self): # pragma: no cover | |||
| # test that the query function | |||
| q = await query_record('com.', 'DS') | |||
| @async_test | |||
| async def test_realresolver(self): # pragma: no cover | |||
| async def xtest_realresolver(self): # pragma: no cover | |||
| # This is to test against a real resolver. | |||
| host = 'gold' | |||
| res = ServerResolver(await asyncio.open_connection(host, 53)) | |||
| resquestion = DNSQuestion('example.com') | |||
| print(repr(await res.resolve(resquestion))) | |||
| @async_test | |||
| async def test_resolvernoanswer(self): | |||
| client, server = _asyncsockpair() | |||
| res = ServerResolver(await client) | |||
| rdr, wrr = await server | |||
| # constants | |||
| resquestion = DNSQuestion('example.com') | |||
| # start the query | |||
| ans = asyncio.create_task(res.resolve(resquestion)) | |||
| # Fetch the question | |||
| question = await _readpkt(rdr) | |||
| dnsrec = DNSRecord.parse(question) | |||
| # Make sure we got the correct question | |||
| self.assertEqual(dnsrec.get_q(), resquestion) | |||
| # close the connection | |||
| wrr.write_eof() | |||
| # make sure it raises an error | |||
| with self.assertRaises(RuntimeError): | |||
| await ans | |||
| @async_test | |||
| async def test_questionresolver(self): | |||
| @@ -211,3 +279,66 @@ class Tests(unittest.TestCase): | |||
| # veryify the received answer | |||
| ans = await ans | |||
| self.assertEqual(ans, resanswer) | |||
| @async_test | |||
| async def test_outoforderanswer(self): | |||
| client, server = _asyncsockpair() | |||
| res = ServerResolver(await client) | |||
| rdr, wrr = await server | |||
| # constants | |||
| resquestiona = DNSQuestion('example.com') | |||
| resquestionb = DNSQuestion('example.net') | |||
| resanswera = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11') | |||
| resanswerb = RR.fromZone('example.net. A 192.0.2.20\nexample.net. A 192.0.2.21') | |||
| # start the first query | |||
| ansa = asyncio.create_task(res.resolve(resquestiona)) | |||
| # Fetch the first question | |||
| question = await _readpkt(rdr) | |||
| dnsreca = DNSRecord.parse(question) | |||
| # Make sure we got the correct question | |||
| self.assertEqual(dnsreca.get_q(), resquestiona) | |||
| # start the second query (before first has been answered) | |||
| ansb = asyncio.create_task(res.resolve(resquestionb)) | |||
| # Fetch the second question | |||
| question = await _readpkt(rdr) | |||
| dnsrecb = DNSRecord.parse(question) | |||
| # Make sure we got the correct question | |||
| self.assertEqual(dnsrecb.get_q(), resquestionb) | |||
| # Generate the second reply | |||
| rep = dnsrecb.reply() | |||
| rep.add_answer(*resanswerb) | |||
| repbytes = rep.pack() | |||
| # Send the reply | |||
| _writepkt(wrr, repbytes) | |||
| # Generate the first reply | |||
| rep = dnsreca.reply() | |||
| rep.add_answer(*resanswera) | |||
| repbytes = rep.pack() | |||
| # Send the reply | |||
| _writepkt(wrr, repbytes) | |||
| # Send a second reply, and make sure it is ignored | |||
| _writepkt(wrr, repbytes) | |||
| # close the connection | |||
| wrr.write_eof() | |||
| # veryify the first received answer | |||
| ansa = await ansa | |||
| self.assertEqual(ansa, resanswera) | |||
| # veryify the second received answer | |||
| ansb = await ansb | |||
| self.assertEqual(ansb, resanswerb) | |||