Browse Source

(tests) Ensure pointer loop does not break the server

main
Michal Charemza 6 years ago
parent
commit
4ae35896bb
No known key found for this signature in database GPG Key ID: 4BBAF0F6B73C4363
1 changed files with 58 additions and 0 deletions
  1. +58
    -0
      test.py

+ 58
- 0
test.py View File

@@ -1,5 +1,7 @@
import asyncio
import ipaddress
import socket
import struct
import unittest


@@ -10,6 +12,8 @@ from aiodnsresolver import (
DnsResponseCode,
IPv4AddressExpiresAt,
Message,
ResourceRecord,
QuestionRecord,
Resolver,
pack,
parse,
@@ -213,6 +217,60 @@ class TestProxy(unittest.TestCase):
for response in responses:
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

@async_test
async def test_sending_pointer_loop_not_affect_later_queries_c(self):
resolve, clear_cache = get_resolver(53)
self.add_async_cleanup(clear_cache)

start = DnsProxy(rules=((r'(^.*$)', r'\1'),))
stop = await start()
self.add_async_cleanup(stop)

response = await resolve('www.google.com', TYPES.A)
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)

name = b'mydomain.com'
question_record = QuestionRecord(name, TYPES.A, qclass=1)
record_1 = ResourceRecord(
name=name, qtype=TYPES.A, qclass=1, ttl=0,
rdata=ipaddress.IPv4Address('123.100.124.1').packed,
)
response = Message(
qid=1, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
qd=(question_record,), an=(record_1,), ns=(), ar=(),
)

data = pack(response)
packed_name = b''.join(
component
for label in name.split(b'.')
for component in (bytes([len(label)]), label)
) + b'\0'

occurance_1 = data.index(packed_name)
occurance_1_end = occurance_1 + len(packed_name)
occurance_2 = occurance_1_end + data[occurance_1_end:].index(packed_name)
occurance_2_end = occurance_2 + len(packed_name)

data_compressed = \
data[:occurance_2] + \
struct.pack('!H', (192 * 256) + occurance_2 + 4) + \
struct.pack('!H', (192 * 256) + occurance_2) + \
struct.pack('!H', (192 * 256) + occurance_2 + 2) + \
data[occurance_2_end:]

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.sendto(data_compressed, ('127.0.0.1', 53))
sock.close()

tasks = [
asyncio.create_task(resolve('www.google.com', TYPES.A))
for _ in range(0, 100000)
]
responses = await asyncio.gather(*tasks)
for response in responses:
self.assertEqual(type(response[0]), IPv4AddressExpiresAt)


def get_socket(port):
def _get_socket():


Loading…
Cancel
Save