import asyncio
import random
import socket
import struct
import unittest

from util import *


# This function based upon code from:
# https://gist.github.com/petrdvor/e802bec72e78ace061ab9d4469418fae#file-async-multicast-receiver-server-py-L54-L72
def make_multisock(maddr):
    """Create a UDP socket bound to *maddr* and joined to its multicast group.

    :param maddr: ``(group_address, port)`` tuple, e.g. ``('224.0.0.199', 3485)``.
    :returns: the bound socket, already a member of the multicast group.

    Only IPv4 groups are supported: membership is requested with
    ``IP_ADD_MEMBERSHIP``, whose request is an IPv4 ``ip_mreq``.  If any
    step fails, the partially configured socket is closed before the
    exception propagates, so no file descriptor is leaked.
    """
    # getaddrinfo entries are (family, type, proto, canonname, sockaddr).
    family, socktype, _proto, _canonname, sockaddr = socket.getaddrinfo(
        *maddr, type=socket.SOCK_DGRAM)[0]
    sock = socket.socket(family, socktype)
    try:
        # Several receivers (e.g. the two in the test below) must be able to
        # bind the same group/port.  SO_REUSEPORT does not exist on every
        # platform, so only set it where the socket module provides it.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, 'SO_REUSEPORT'):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(maddr)
        group_bin = socket.inet_pton(family, sockaddr[0])
        # ip_mreq layout: group address followed by the local interface
        # address; INADDR_ANY lets the kernel choose the interface.
        mreq = group_bin + struct.pack('=I', socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
    except BaseException:
        sock.close()
        raise
    return sock


class StupidProtocol(asyncio.DatagramProtocol):
    """Minimal datagram protocol that just remembers its transport.

    Deriving from :class:`asyncio.DatagramProtocol` gives default
    implementations of every callback the event loop may invoke
    (``error_received``, ``pause_writing``, ``resume_writing``), so those
    calls no longer fail with ``AttributeError``.
    """

    def __init__(self):
        super().__init__()
        self.transport = None

    def close(self):
        """Close the transport if it is still attached; safe to call twice."""
        # connection_lost() resets transport to None, so guard against a
        # close() that arrives after the connection is already gone.
        if self.transport is not None:
            self.transport.close()

    def connection_lost(self, exc):
        self.transport = None

    def connection_made(self, transport):
        # NOTE: the create_* helpers below hand back the protocol and expect
        # ``transport`` to be set already, i.e. connection_made() has run by
        # the time create_datagram_endpoint() returns.  That holds in CPython
        # but is not spelled out in the asyncio docs.
        self.transport = transport


class ReceiverProtocol(StupidProtocol):
    """Datagram protocol that queues incoming datagrams for :meth:`recv`."""

    def __init__(self):
        super().__init__()
        # Items are (data, addr) tuples, or an exception the transport
        # reported through error_received().
        self._q = asyncio.Queue()

    def datagram_received(self, data, addr):
        self._q.put_nowait((data, addr))

    def error_received(self, exc):
        # Hand transport errors to the awaiting recv() instead of dropping
        # them, so a caller fails fast rather than waiting forever.
        self._q.put_nowait(exc)

    async def recv(self):
        """Wait for the next datagram and return it as ``(data, addr)``.

        Raises the exception the transport reported if an error was
        received instead of a datagram.
        """
        item = await self._q.get()
        if isinstance(item, BaseException):
            raise item
        return item


async def create_multicast_receiver(maddr):
    """Return a :class:`ReceiverProtocol` listening on multicast group *maddr*."""
    sock = make_multisock(maddr)
    loop = asyncio.get_running_loop()
    try:
        _transport, protocol = await loop.create_datagram_endpoint(
            ReceiverProtocol, sock=sock)
    except BaseException:
        # If endpoint creation fails, don't leak the socket we created.
        # Closing a socket object that was already closed is a no-op.
        sock.close()
        raise
    return protocol


class TransmitterProtocol(StupidProtocol):
    """Datagram protocol connected to a single remote (multicast) address."""

    async def send(self, msg):
        """Send *msg* to the remote address the endpoint was created with."""
        self.transport.sendto(msg)


async def create_multicast_transmitter(maddr):
    """Return a :class:`TransmitterProtocol` whose datagrams go to *maddr*."""
    loop = asyncio.get_running_loop()
    _transport, protocol = await loop.create_datagram_endpoint(
        TransmitterProtocol, remote_addr=maddr)
    return protocol


class TestMulticast(unittest.IsolatedAsyncioTestCase):

    @timeout(2)
    async def test_multicast(self):
        # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
        maddr = ('224.0.0.199', 3485)
        # Register each close immediately so endpoints are released even
        # when an assertion or the timeout fails the test part-way through.
        l1 = await create_multicast_receiver(maddr)
        self.addCleanup(l1.close)
        l2 = await create_multicast_receiver(maddr)
        self.addCleanup(l2.close)
        t1 = await create_multicast_transmitter(maddr)
        self.addCleanup(t1.close)

        msg = b'test message'
        await t1.send(msg)
        await t1.send(msg)

        # Every receiver joined to the group gets its own copy of each datagram.
        self.assertEqual((await l1.recv())[0], msg)
        self.assertEqual((await l2.recv())[0], msg)
        self.assertEqual((await l1.recv())[0], msg)
        self.assertEqual((await l2.recv())[0], msg)