From 08e6f939922f5b85169be1a266bde541c3c47f9c Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Tue, 4 May 2021 18:23:19 -0700 Subject: [PATCH] break out some util functions, implement multicast adapter for LORANode... --- lora.py | 187 ++++++++++++++++++++++++++++++++++++++++++++------- multicast.py | 4 +- util.py | 28 ++++++++ 3 files changed, 191 insertions(+), 28 deletions(-) create mode 100644 util.py diff --git a/lora.py b/lora.py index 6a830c1..a356077 100644 --- a/lora.py +++ b/lora.py @@ -34,6 +34,8 @@ from Strobe.Strobe import AuthenticationFailed import lora_comms from lora_comms import make_pktbuf +import multicast +from util import * domain = b'com.funkthat.lora.irrigation.shared.v0.0.1' @@ -127,9 +129,6 @@ class LORANode(object): class SyncDatagram(object): '''Base interface for a more simple synchronous interface.''' - def __init__(self): #pragma: no cover - pass - async def recv(self, timeout=None): #pragma: no cover '''Receive a datagram. If timeout is not None, wait that many seconds, and if nothing is received in that time, raise an @@ -154,6 +153,56 @@ class SyncDatagram(object): except TimeoutError: pass +class MulticastSyncDatagram(SyncDatagram): + ''' + An implementation of SyncDatagram that uses the provided + multicast address maddr as the source/sink of the packets. + + Note that once created, the start coroutine needs to be + await'd before being passed to a LORANode so that everything + is running. + ''' + + # Note: sent packets will be received. A similar method to + # what was done in multicast.{to,from}_loragw could be done + # here as well, that is passing in a set of packets to not + # pass back up. + + def __init__(self, maddr): + self.maddr = maddr + self._ignpkts = set() + + async def start(self): + self.mr = await multicast.create_multicast_receiver(self.maddr) + self.mt = await multicast.create_multicast_transmitter( + self.maddr) + + async def _recv(self): + while True: + pkt = await self.mr.recv() + pkt = pkt[0] + if pkt not in self._ignpkts: + return pkt + + self._ignpkts.remove(pkt) + + async def recv(self, timeout=None): #pragma: no cover + r = await asyncio.wait_for(self._recv(), timeout=timeout) + + return r + + async def send(self, data): #pragma: no cover + self._ignpkts.add(bytes(data)) + await self.mt.send(data) + + def close(self): + '''Shutdown communications.''' + + self.mr.close() + self.mr = None + self.mt.close() + self.mt = None + class MockSyncDatagram(SyncDatagram): '''A testing version of SyncDatagram. Define a method runner which implements part of the sequence. In the function, await on either @@ -211,29 +260,6 @@ class TestSyncData(unittest.IsolatedAsyncioTestCase): self.assertEqual(r, b'a') self.assertEqual(ms.sendq, [ b'foo', b'foo' ]) -def timeout(timeout): - def timeout_wrapper(fun): - @functools.wraps(fun) - async def wrapper(*args, **kwargs): - return await asyncio.wait_for(fun(*args, **kwargs), - timeout) - - return wrapper - - return timeout_wrapper - -def _debprint(*args): # pragma: no cover - import traceback, sys, os.path - st = traceback.extract_stack(limit=2)[0] - - sep = '' - if args: - sep = ':' - - print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep), - *args) - sys.stdout.flush() - class AsyncSequence(object): ''' Object used for sequencing async functions. To use, use the @@ -267,6 +293,10 @@ class AsyncSequence(object): next(self.positer): self.token } + async def simpsync(self, pos): + async with self.sync(pos): + pass + @contextlib.asynccontextmanager async def sync(self, pos): '''An async context manager that will be run when it's @@ -877,3 +907,108 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): # Make sure all the expected messages have been # processed. self.assertFalse(exptmsgs) + +class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase): + # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1 + maddr = ('224.0.0.198', 48542) + + @timeout(2) + async def test_multisyncdgram(self): + # Test the implementation of the multicast version of + # SyncDatagram + + _self = self + from ctypes import c_uint8 + + # seed the RNG + prngseed = b'abc123' + lora_comms.strobe_seed_prng((c_uint8 * + len(prngseed))(*prngseed), len(prngseed)) + + # Create the state for testing + commstate = lora_comms.CommsState() + + # These are the expected messages and their arguments + exptmsgs = [ + (CMD_WAITFOR, [ 30 ]), + (CMD_PING, [ ]), + (CMD_TERMINATE, [ ]), + ] + def procmsg(msg, outbuf): + msgbuf = msg._from() + cmd = msgbuf[0] + args = [ int.from_bytes(msgbuf[x:x + 4], + byteorder='little') for x in range(1, len(msgbuf), + 4) ] + + if exptmsgs[0] == (cmd, args): + exptmsgs.pop(0) + outbuf[0].pkt[0] = cmd + outbuf[0].pktlen = 1 + else: #pragma: no cover + raise RuntimeError('cmd not found') + + # wrap the callback function + cb = lora_comms.process_msgfunc_t(procmsg) + + # Generate shared key + shared_key = os.urandom(32) + + # Initialize everything + lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key)) + + # create the object we are testing + msd = MulticastSyncDatagram(self.maddr) + + seq = AsyncSequence() + + async def clienttask(): + mr = await multicast.create_multicast_receiver( + self.maddr) + mt = await multicast.create_multicast_transmitter( + self.maddr) + + try: + # make sure the above threads are running + await seq.simpsync(0) + + while True: + pkt = await mr.recv() + msg = pkt[0] + + out = lora_comms.comms_process_wrap( + commstate, msg) + + if out: + await mt.send(out) + finally: + mr.close() + mt.close() + + task = asyncio.create_task(clienttask()) + + # start it + await msd.start() + + # pass it to a node + l = LORANode(msd, shared=shared_key) + + await seq.simpsync(1) + + # Send various messages + await l.start() + + await l.waitfor(30) + + await l.ping() + + await l.terminate() + + # shut things down + ln = None + msd.close() + + task.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task diff --git a/multicast.py b/multicast.py index b9298d9..e40f562 100644 --- a/multicast.py +++ b/multicast.py @@ -4,7 +4,7 @@ import socket import struct import unittest -from lora import timeout +from util import * # This function based upon code from: # https://gist.github.com/petrdvor/e802bec72e78ace061ab9d4469418fae#file-async-multicast-receiver-server-py-L54-L72 @@ -77,7 +77,7 @@ class TestMulticast(unittest.IsolatedAsyncioTestCase): @timeout(2) async def test_multicast(self): # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1 - maddr = ('224.0.0.%d' % random.randint(151, 250), 3485) + maddr = ('224.0.0.199', 3485) l1 = await create_multicast_receiver(maddr) l2 = await create_multicast_receiver(maddr) diff --git a/util.py b/util.py new file mode 100644 index 0000000..2959bfb --- /dev/null +++ b/util.py @@ -0,0 +1,28 @@ +import asyncio +import functools + +__all__ = [ 'timeout', '_debprint' ] + +def timeout(timeout): + def timeout_wrapper(fun): + @functools.wraps(fun) + async def wrapper(*args, **kwargs): + return await asyncio.wait_for(fun(*args, **kwargs), + timeout) + + return wrapper + + return timeout_wrapper + +def _debprint(*args): # pragma: no cover + import traceback, sys, os.path + st = traceback.extract_stack(limit=2)[0] + + sep = '' + if args: + sep = ':' + + print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep), + *args) + sys.stdout.flush() +