| @@ -43,7 +43,7 @@ def get_logger_default(): | |||||
| def DnsProxy( | def DnsProxy( | ||||
| get_resolver=get_resolver_default, get_logger=get_logger_default, | get_resolver=get_resolver_default, get_logger=get_logger_default, | ||||
| get_socket=get_socket_default, | get_socket=get_socket_default, | ||||
| num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000, | |||||
| num_workers=1000, upstream_queue_maxsize=10000, | |||||
| rules=(), | rules=(), | ||||
| ): | ): | ||||
| @@ -61,20 +61,12 @@ def DnsProxy( | |||||
| # workers | # workers | ||||
| async def server_worker(sock, resolve): | async def server_worker(sock, resolve): | ||||
| downstream_queue = Queue(maxsize=downstream_queue_maxsize) | |||||
| upstream_queue = Queue(maxsize=upstream_queue_maxsize) | upstream_queue = Queue(maxsize=upstream_queue_maxsize) | ||||
| # It would "usually" be ok to send downstream from multiple tasks, but | |||||
| # if the socket has a full buffer, it would raise a BlockingIOError, | |||||
| # and we will need to attach a reader. We can only attach one reader | |||||
| # per underlying file, and since we have a single socket, we have a | |||||
| # single file. So we send downstream from a single task | |||||
| downstream_worker_task = create_task(downstream_worker(sock, downstream_queue)) | |||||
| # We have multiple upstream workers to be able to send multiple | # We have multiple upstream workers to be able to send multiple | ||||
| # requests upstream concurrently, and add responses to downstream_queue | # requests upstream concurrently, and add responses to downstream_queue | ||||
| upstream_worker_tasks = [ | upstream_worker_tasks = [ | ||||
| create_task(upstream_worker(resolve, upstream_queue, downstream_queue)) | |||||
| create_task(upstream_worker(sock, resolve, upstream_queue)) | |||||
| for _ in range(0, num_workers)] | for _ in range(0, num_workers)] | ||||
| try: | try: | ||||
| @@ -82,50 +74,31 @@ def DnsProxy( | |||||
| request_data, addr = await recvfrom(loop, [sock], 512) | request_data, addr = await recvfrom(loop, [sock], 512) | ||||
| await upstream_queue.put((request_data, addr)) | await upstream_queue.put((request_data, addr)) | ||||
| finally: | finally: | ||||
| # Finish upstream requests, which can add to to the downstream | |||||
| # queue | |||||
| # Finish upstream requests | |||||
| await upstream_queue.join() | await upstream_queue.join() | ||||
| for upstream_task in upstream_worker_tasks: | for upstream_task in upstream_worker_tasks: | ||||
| upstream_task.cancel() | upstream_task.cancel() | ||||
| # Ensure we have sent the responses downstream | |||||
| await downstream_queue.join() | |||||
| downstream_worker_task.cancel() | |||||
| # Wait for the tasks to really be finished | |||||
| for upstream_task in upstream_worker_tasks: | for upstream_task in upstream_worker_tasks: | ||||
| try: | try: | ||||
| await upstream_task | await upstream_task | ||||
| except Exception: | except Exception: | ||||
| pass | pass | ||||
| try: | |||||
| await downstream_worker_task | |||||
| except Exception: | |||||
| pass | |||||
| async def upstream_worker(resolve, upstream_queue, downstream_queue): | |||||
| async def upstream_worker(sock, resolve, upstream_queue): | |||||
| while True: | while True: | ||||
| request_data, addr = await upstream_queue.get() | request_data, addr = await upstream_queue.get() | ||||
| try: | try: | ||||
| response_data = await get_response_data(resolve, request_data) | response_data = await get_response_data(resolve, request_data) | ||||
| # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError | |||||
| # https://stackoverflow.com/a/59794872/1319998 | |||||
| sock.sendto(response_data, addr) | |||||
| except Exception: | except Exception: | ||||
| logger.exception('Exception from handler_request_data %s', addr) | |||||
| continue | |||||
| else: | |||||
| await downstream_queue.put((response_data, addr)) | |||||
| logger.exception('Processing request from %s', addr) | |||||
| finally: | finally: | ||||
| upstream_queue.task_done() | upstream_queue.task_done() | ||||
| async def downstream_worker(sock, downstream_queue): | |||||
| while True: | |||||
| response_data, addr = await downstream_queue.get() | |||||
| try: | |||||
| await sendto(sock, response_data, addr) | |||||
| except Exception: | |||||
| logger.exception('Unable to send response to %s', addr) | |||||
| downstream_queue.task_done() | |||||
| async def get_response_data(resolve, request_data): | async def get_response_data(resolve, request_data): | ||||
| # This may raise an exception, which is handled at a higher level. | # This may raise an exception, which is handled at a higher level. | ||||
| # We can't [and I suspect shouldn't try to] return an error to the | # We can't [and I suspect shouldn't try to] return an error to the | ||||
| @@ -206,17 +179,3 @@ def error(query, rcode): | |||||
| qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode, | qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode, | ||||
| qd=query.qd, an=(), ns=(), ar=(), | qd=query.qd, an=(), ns=(), ar=(), | ||||
| ) | ) | ||||
| async def sendto(sock, data, addr): | |||||
| # In our cases, the UDP responses will always be 512 bytes or less. | |||||
| # Even if sendto sent some of the data, there is no way for the other | |||||
| # end to reconstruct their order, so we don't include any logic to send | |||||
| # the rest of the data. Since it's UDP, the client already has to have | |||||
| # retry logic. | |||||
| # | |||||
| # Potentially also, this can raise a BlockingIOError, but even trying | |||||
| # to force high numbers of messages with a small socket buffer, this has | |||||
| # never been observed. As above, the client must have retry logic, so we | |||||
| # leave it to the client to deal with this. | |||||
| return sock.sendto(data, addr) | |||||