|
|
@@ -75,16 +75,58 @@ class ServerResolver(object): |
|
|
|
|
|
|
|
self._reader, self._writer = rdrwrr |
|
|
|
|
|
|
|
self._qs = {} |
|
|
|
|
|
|
|
self._replies = asyncio.create_task(self.procreplies(self._reader)) |
|
|
|
|
|
|
|
def __del__(self): |
|
|
|
if self._replies: |
|
|
|
self._replies.cancel() |
|
|
|
self._replies = None |
|
|
|
|
|
|
|
self._writer.write_eof() |
|
|
|
|
|
|
|
def cleanup(self): |
|
|
|
while self._qs: |
|
|
|
i = self._qs.popitem()[1] |
|
|
|
i.set_exception(RuntimeError( |
|
|
|
'server connection closed, ' |
|
|
|
'no response received')) |
|
|
|
|
|
|
|
async def procreplies(self, rdr): |
|
|
|
while True: |
|
|
|
# Get a response |
|
|
|
try: |
|
|
|
pkt = await _readpkt(self._reader) |
|
|
|
except asyncio.streams.IncompleteReadError: |
|
|
|
self.cleanup() |
|
|
|
return |
|
|
|
|
|
|
|
resp = DNSRecord.parse(pkt) |
|
|
|
|
|
|
|
# fetch where we should send it |
|
|
|
fut = self._qs.pop(resp.header.id, None) |
|
|
|
if fut is None: |
|
|
|
continue |
|
|
|
|
|
|
|
# send the result |
|
|
|
fut.set_result(resp.rr) |
|
|
|
|
|
|
|
async def resolve(self, q): |
|
|
|
# make DNS packet |
|
|
|
pkt = DNSRecord(questions=[q]) |
|
|
|
pktbytes = pkt.pack() |
|
|
|
|
|
|
|
_writepkt(self._writer, pktbytes) |
|
|
|
# Where the procreplies function will send the answer |
|
|
|
res = asyncio.Future() |
|
|
|
|
|
|
|
# Record it |
|
|
|
self._qs[pkt.header.id] = res |
|
|
|
|
|
|
|
pkt = await _readpkt(self._reader) |
|
|
|
resp = DNSRecord.parse(pkt) |
|
|
|
# Send query |
|
|
|
_writepkt(self._writer, pktbytes) |
|
|
|
|
|
|
|
return resp.rr |
|
|
|
return await res |
|
|
|
|
|
|
|
class DNSProc(asyncio.DatagramProtocol): |
|
|
|
def connection_made(self, transport): |
|
|
@@ -165,18 +207,44 @@ class Tests(unittest.TestCase): |
|
|
|
bytes.fromhex('E06D44B80B8F1D39A95C0B0D7C65D08458E880409BBC683457104237C7F8EC8D')) ]) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def xtest_query(self): |
|
|
|
async def xtest_query(self): # pragma: no cover |
|
|
|
# test that the query function |
|
|
|
q = await query_record('com.', 'DS') |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_realresolver(self): # pragma: no cover |
|
|
|
async def xtest_realresolver(self): # pragma: no cover |
|
|
|
# This is to test against a real resolver. |
|
|
|
host = 'gold' |
|
|
|
res = ServerResolver(await asyncio.open_connection(host, 53)) |
|
|
|
|
|
|
|
resquestion = DNSQuestion('example.com') |
|
|
|
print(repr(await res.resolve(resquestion))) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_resolvernoanswer(self): |
|
|
|
client, server = _asyncsockpair() |
|
|
|
res = ServerResolver(await client) |
|
|
|
|
|
|
|
rdr, wrr = await server |
|
|
|
|
|
|
|
# constants |
|
|
|
resquestion = DNSQuestion('example.com') |
|
|
|
|
|
|
|
# start the query |
|
|
|
ans = asyncio.create_task(res.resolve(resquestion)) |
|
|
|
|
|
|
|
# Fetch the question |
|
|
|
question = await _readpkt(rdr) |
|
|
|
dnsrec = DNSRecord.parse(question) |
|
|
|
|
|
|
|
# Make sure we got the correct question |
|
|
|
self.assertEqual(dnsrec.get_q(), resquestion) |
|
|
|
|
|
|
|
# close the connection |
|
|
|
wrr.write_eof() |
|
|
|
|
|
|
|
# make sure it raises an error |
|
|
|
with self.assertRaises(RuntimeError): |
|
|
|
await ans |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_questionresolver(self): |
|
|
@@ -211,3 +279,66 @@ class Tests(unittest.TestCase): |
|
|
|
# veryify the received answer |
|
|
|
ans = await ans |
|
|
|
self.assertEqual(ans, resanswer) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_outoforderanswer(self): |
|
|
|
client, server = _asyncsockpair() |
|
|
|
res = ServerResolver(await client) |
|
|
|
|
|
|
|
rdr, wrr = await server |
|
|
|
|
|
|
|
# constants |
|
|
|
resquestiona = DNSQuestion('example.com') |
|
|
|
resquestionb = DNSQuestion('example.net') |
|
|
|
resanswera = RR.fromZone('example.com. A 192.0.2.10\nexample.com. A 192.0.2.11') |
|
|
|
resanswerb = RR.fromZone('example.net. A 192.0.2.20\nexample.net. A 192.0.2.21') |
|
|
|
|
|
|
|
# start the first query |
|
|
|
ansa = asyncio.create_task(res.resolve(resquestiona)) |
|
|
|
|
|
|
|
# Fetch the first question |
|
|
|
question = await _readpkt(rdr) |
|
|
|
dnsreca = DNSRecord.parse(question) |
|
|
|
|
|
|
|
# Make sure we got the correct question |
|
|
|
self.assertEqual(dnsreca.get_q(), resquestiona) |
|
|
|
|
|
|
|
# start the second query (before first has been answered) |
|
|
|
ansb = asyncio.create_task(res.resolve(resquestionb)) |
|
|
|
|
|
|
|
# Fetch the second question |
|
|
|
question = await _readpkt(rdr) |
|
|
|
dnsrecb = DNSRecord.parse(question) |
|
|
|
|
|
|
|
# Make sure we got the correct question |
|
|
|
self.assertEqual(dnsrecb.get_q(), resquestionb) |
|
|
|
|
|
|
|
# Generate the second reply |
|
|
|
rep = dnsrecb.reply() |
|
|
|
rep.add_answer(*resanswerb) |
|
|
|
repbytes = rep.pack() |
|
|
|
|
|
|
|
# Send the reply |
|
|
|
_writepkt(wrr, repbytes) |
|
|
|
|
|
|
|
# Generate the first reply |
|
|
|
rep = dnsreca.reply() |
|
|
|
rep.add_answer(*resanswera) |
|
|
|
repbytes = rep.pack() |
|
|
|
|
|
|
|
# Send the reply |
|
|
|
_writepkt(wrr, repbytes) |
|
|
|
|
|
|
|
# Send a second reply, and make sure it is ignored |
|
|
|
_writepkt(wrr, repbytes) |
|
|
|
|
|
|
|
# close the connection |
|
|
|
wrr.write_eof() |
|
|
|
|
|
|
|
# veryify the first received answer |
|
|
|
ansa = await ansa |
|
|
|
self.assertEqual(ansa, resanswera) |
|
|
|
|
|
|
|
# veryify the second received answer |
|
|
|
ansb = await ansb |
|
|
|
self.assertEqual(ansb, resanswerb) |