# Copyright 2021 John-Mark Gurney. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF # SUCH DAMAGE. # import asyncio import contextlib import functools import itertools import os import sys import unittest from Strobe.Strobe import Strobe, KeccakF from Strobe.Strobe import AuthenticationFailed import lora_comms from lora_comms import make_pktbuf import multicast from util import * # Response to command will be the CMD and any arguments if needed. 
# The command is encoded as an unsigned byte
CMD_TERMINATE = 1   # no args: terminate the session, reply confirms

# The following commands are queued up, and will be acknowledged when queued
CMD_WAITFOR = 2     # arg: (length): waits for length seconds
CMD_RUNFOR = 3      # arg: (chan, length): turns on chan for length seconds
CMD_PING = 4        # arg: (): a no op command
CMD_SETUNSET = 5    # arg: (chan, val): sets chan to val
CMD_ADV = 6         # arg: ([cnt]): advances to the next cnt (default 1) command
CMD_CLEAR = 7       # arg: (): clears all future commands, but keeps current running


class LORANode(object):
    '''Implement a LORANode initiator.

    Every message sent is encrypted (Strobe send_enc) and followed by
    a MAC of MAC_LEN bytes; every reply must decrypt and authenticate
    under the same Strobe state, otherwise it is ignored.'''

    SHARED_DOMAIN = b'com.funkthat.lora.irrigation.shared.v0.0.1'
    ECDHE_DOMAIN = b'com.funkthat.lora.irrigation.ecdhe.v0.0.1'

    MAC_LEN = 8

    def __init__(self, syncdatagram, shared=None, ecdhe_key=None, resp_pub=None):
        '''Create an initiator that talks over syncdatagram (a
        SyncDatagram instance).

        shared is the shared key used to key the Strobe state; it is
        currently required.  ecdhe_key and resp_pub are accepted but
        not used by this code.

        Raises RuntimeError if no shared key is provided.'''

        self.sd = syncdatagram
        self.st = Strobe(self.SHARED_DOMAIN, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)
        else:
            raise RuntimeError('a shared key must be provided')

    async def start(self):
        '''Run the session handshake.

        Sends a random nonce followed by b'reqreset', ratchets the
        Strobe state (so later messages cannot be rolled back to the
        pre-handshake state), then sends b'confirm' and requires the
        reply b'confirmed'.

        Raises RuntimeError if the confirmation reply is wrong.'''

        # the content of the reply to the reset request is not
        # checked, only that it authenticates
        await self.sendrecvvalid(os.urandom(16) + b'reqreset')

        self.st.ratchet()

        pkt = await self.sendrecvvalid(b'confirm')

        if pkt != b'confirmed':
            raise RuntimeError('got invalid response: %s' % repr(pkt))

    async def sendrecvvalid(self, msg):
        '''Encrypt and MAC msg, and resend it until a reply that
        authenticates is received.  Returns the decrypted reply.

        Replies that are empty or fail authentication are ignored.'''

        msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)

        # snapshot the state after sending, so a reply that fails
        # authentication does not corrupt the state for the next try
        origstate = self.st.copy()

        while True:
            resp = await self.sd.sendtillrecv(msg, .50)
            #_debprint('got:', resp)

            # skip empty messages
            if len(resp) == 0:
                continue

            try:
                decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
                self.st.recv_mac(resp[-self.MAC_LEN:])
                break
            except AuthenticationFailed:
                # didn't get a valid packet, restore
                # state and retry

                #_debprint('failed')
                self.st.set_state_from(origstate)

        #_debprint('got rep:', repr(resp), repr(decmsg))
        return decmsg

    @staticmethod
    def _encodeargs(*args):
        '''Encode each integer argument as 4 little-endian bytes and
        concatenate them.'''

        r = []
        for i in args:
            r.append(i.to_bytes(4, byteorder='little'))

        return b''.join(r)

    async def _sendcmd(self, cmd, *args):
        '''Send command byte cmd with the encoded args, and verify that
        the reply starts with the same command byte.

        Raises RuntimeError if the reply's command byte differs.'''

        cmdbyte = cmd.to_bytes(1, byteorder='little')
        resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args))

        if resp[0:1] != cmdbyte:
            raise RuntimeError(
                'response does not match, got: %s, expected: %s' %
                (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Send CMD_WAITFOR with length.'''

        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Send CMD_RUNFOR with chan and length.'''

        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def setunset(self, chan, val):
        '''Send CMD_SETUNSET with chan and val.'''

        return await self._sendcmd(CMD_SETUNSET, chan, val)

    async def ping(self):
        '''Send CMD_PING (no arguments).'''

        return await self._sendcmd(CMD_PING)

    async def adv(self, cnt=None):
        '''Send CMD_ADV, with cnt only if it is provided.'''

        args = ()
        if cnt is not None:
            args = (cnt, )

        return await self._sendcmd(CMD_ADV, *args)

    async def clear(self):
        '''Send CMD_CLEAR (no arguments).'''

        return await self._sendcmd(CMD_CLEAR)

    async def terminate(self):
        '''Send CMD_TERMINATE (no arguments).'''

        return await self._sendcmd(CMD_TERMINATE)


class SyncDatagram(object):
    '''Base interface for a more simple synchronous interface.'''

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram.  If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise an
        asyncio.TimeoutError exception.'''

        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send the datagram in data.'''

        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data every freq seconds until a
        datagram is received, and return the received datagram.

        There is no overall timeout: this keeps resending for as long
        as recv raises asyncio.TimeoutError.'''

        while True:
            #_debprint('sending:', repr(data))
            await self.send(data)
            try:
                return await self.recv(freq)
            except asyncio.TimeoutError:
                pass


class MulticastSyncDatagram(SyncDatagram):
    '''
    An implementation of SyncDatagram that uses the provided
    multicast address maddr as the source/sink of the packets.

    Note that once created, the start coroutine needs to be
    await'd before being passed to a LORANode so that everything
    is running.
    '''

    # Note: sent packets will be received.  A similar method to
    # what was done in multicast.{to,from}_loragw could be done
    # here as well, that is passing in a set of packets to not
    # pass back up.

    def __init__(self, maddr):
        self.maddr = maddr

        # Maps each packet we sent to the number of its echoes still
        # expected back.  A count (not a set) is required because
        # sendtillrecv resends identical data, and each send produces
        # its own echo that must be swallowed.
        self._ignpkts = {}

    async def start(self):
        '''Create the multicast receiver and transmitter.'''

        self.mr = await multicast.create_multicast_receiver(self.maddr)
        self.mt = await multicast.create_multicast_transmitter(
            self.maddr)

    async def _recv(self):
        '''Return the next received packet that is not an echo of one
        we sent ourselves.'''

        while True:
            pkt = await self.mr.recv()
            pkt = pkt[0]

            cnt = self._ignpkts.get(pkt)
            if cnt is None:
                return pkt

            # one of our own packets: consume exactly one echo
            if cnt <= 1:
                del self._ignpkts[pkt]
            else:
                self._ignpkts[pkt] = cnt - 1

    async def recv(self, timeout=None): #pragma: no cover
        r = await asyncio.wait_for(self._recv(), timeout=timeout)

        return r

    async def send(self, data): #pragma: no cover
        key = bytes(data)

        # record before sending so the echo can never arrive first
        self._ignpkts[key] = self._ignpkts.get(key, 0) + 1

        await self.mt.send(data)

    def close(self):
        '''Shutdown communications.'''

        self.mr.close()
        self.mr = None
        self.mt.close()
        self.mt = None


def listsplit(lst, item):
    '''Split lst at the first occurrence of item.  Returns the tuple
    (before, after), excluding item itself.  If item is not present,
    returns (lst, []).'''

    try:
        idx = lst.index(item)
    except ValueError:
        return lst, []

    return lst[:idx], lst[idx + 1:]


async def main():
    '''Command line entry point.

    With -r, run a responder that passes decoded messages to the
    named function.  Otherwise, open a session as an initiator and
    send the commands given on the command line (separated by '--')
    or read from the file given with -f (one command per line).'''

    import argparse

    from loraserv import DEFAULT_MADDR as maddr

    parser = argparse.ArgumentParser()

    parser.add_argument('-f', dest='schedfile', metavar='filename', type=str,
        help='Use commands from the file. One command per line.')
    parser.add_argument('-r', dest='client', metavar='module:function', type=str,
        help='Create a respondant instead of sending commands. Commands will be passed to the function.')
    parser.add_argument('-s', dest='shared_key', metavar='shared_key', type=str, required=True,
        help='The shared key (encoded as UTF-8) to use.')
    parser.add_argument('args', metavar='CMD_ARG', type=str, nargs='*',
        help='Various commands to send to the device.')

    args = parser.parse_args()

    shared_key = args.shared_key.encode('utf-8')

    if args.client:
        # Run a client
        mr = await multicast.create_multicast_receiver(maddr)
        mt = await multicast.create_multicast_transmitter(maddr)

        from ctypes import c_uint8

        # seed the RNG
        prngseed = os.urandom(64)
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the comms state
        commstate = lora_comms.CommsState()

        import util_load

        client_func = util_load.load_application(args.client)

        def client_call(msg, outbuf):
            ret = client_func(msg._from())

            if len(ret) > outbuf[0].pktlen:
                ret = b'error, too long buffer: %d' % len(ret)

            # never write past the buffer the C code provided
            outbuf[0].pktlen = min(len(ret), outbuf[0].pktlen)
            for i in range(outbuf[0].pktlen):
                outbuf[0].pkt[i] = ret[i]

        cb = lora_comms.process_msgfunc_t(client_call)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        try:
            while True:
                pkt = await mr.recv()
                msg = pkt[0]

                out = lora_comms.comms_process_wrap(
                    commstate, msg)

                if out:
                    await mt.send(out)
        finally:
            mr.close()
            mt.close()
        sys.exit(0)

    # Work out and validate the schedule before touching the network,
    # so that usage errors do not leave a half-run session behind.
    if args.args and args.schedfile:
        parser.error('only one of -f or arguments can be specified.')

    if args.args:
        cmds = list(args.args)
        cmdargs = []
        while cmds:
            a, cmds = listsplit(cmds, '--')
            cmdargs.append(a)
    elif args.schedfile:
        with open(args.schedfile) as fp:
            cmdargs = [ x.split() for x in fp ]
    else:
        parser.error('one of -f or arguments must be specified.')

    # drop empty commands (blank lines, or leading/repeated '--')
    cmdargs = [ x for x in cmdargs if x ]

    valid_cmds = {
        'waitfor', 'setunset', 'runfor', 'ping', 'adv', 'clear',
        'terminate',
    }

    for cmd, *_ in cmdargs:
        if cmd not in valid_cmds:
            print('invalid command:', repr(cmd))
            sys.exit(1)

    msd = MulticastSyncDatagram(maddr)
    await msd.start()

    try:
        l = LORANode(msd, shared=shared_key)

        await l.start()

        for cmd, *cmdparams in cmdargs:
            fun = getattr(l, cmd)

            await fun(*(int(x) for x in cmdparams))
    finally:
        msd.close()

if __name__ == '__main__':
    asyncio.run(main())

class MockSyncDatagram(SyncDatagram):
    '''A testing version of SyncDatagram.  Define a method runner which
    implements part of the sequence.  In the function, await on either
    self.get, to wait for the other side to send something, or await
    self.put w/ data to send.'''

    def __init__(self):
        self.sendq = asyncio.Queue()
        self.recvq = asyncio.Queue()
        self.task = asyncio.create_task(self.runner())

        self.get = self.sendq.get
        self.put = self.recvq.put

    async def drain(self):
        '''Wait for the runner thread to finish up.'''

        return await self.task

    async def runner(self): #pragma: no cover
        raise NotImplementedError

    async def recv(self, timeout=None):
        # timeout is ignored: the runner fully drives the exchange
        return await self.recvq.get()

    async def send(self, data):
        return await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        if self.task is not None and not self.task.done():
            self.task.cancel()

class TestSyncData(unittest.IsolatedAsyncioTestCase):
    async def test_syncsendtillrecv(self):
        class MySync(SyncDatagram):
            def __init__(self):
                self.sendq = []
                self.resp = [ asyncio.TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                assert timeout == 1
                r = self.resp.pop(0)
                if isinstance(r, Exception):
                    raise r

                return r

            async def send(self, data):
                self.sendq.append(data)

        ms = MySync()

        r = await ms.sendtillrecv(b'foo', 1)

        self.assertEqual(r, b'a')
        self.assertEqual(ms.sendq, [ b'foo', b'foo' ])

class AsyncSequence(object):
    '''
    Object used for sequencing async functions.  To use, use the
    asynchronous context manager created by the sync method.  For
    example:
    seq = AsyncSequence()
    async def func1():
        async with seq.sync(1):
            second_fun()

    async def func2():
        async with seq.sync(0):
            first_fun()

    This will make sure that function first_fun is run before running
    the function second_fun.  If a previous block raises an Exception,
    it will be passed up, and all remaining blocks (and future ones)
    will raise a CancelledError to help ensure that any tasks are
    properly cleaned up.
    '''

    def __init__(self, positerfactory=lambda: itertools.count()):
        '''The argument positerfactory, is a factory that will
        create an iterator that will be used for the values that
        are passed to the sync method.'''

        self.positer = positerfactory()
        # marker meaning "this position may run immediately"
        self.token = object()
        self.die = False
        self.waiting = {
            next(self.positer): self.token
        }

    async def simpsync(self, pos):
        '''Wait for position pos, then immediately release it.'''

        async with self.sync(pos):
            pass

    @contextlib.asynccontextmanager
    async def sync(self, pos):
        '''An async context manager that will be run when its
        turn arrives.  It will only run when all the previous
        items in the iterator have been successfully run.

        Raises RuntimeError if another caller is already waiting
        on pos, and CancelledError if an earlier block failed.'''

        if self.die:
            raise asyncio.CancelledError('seq cancelled')

        if pos in self.waiting:
            if self.waiting[pos] is not self.token:
                raise RuntimeError('pos already waiting!')
        else:
            fut = asyncio.Future()
            self.waiting[pos] = fut
            await fut

        # our time to shine!
        del self.waiting[pos]

        try:
            yield None
        except Exception as e:
            # if we got an exception, things went pear shaped,
            # shut everything down, and any future calls.

            #_debprint('dying...', repr(e))
            self.die = True

            # cancel existing blocks
            while self.waiting:
                k, v = self.waiting.popitem()
                #_debprint('canceling: %s' % repr(k))
                if v is self.token:
                    continue

                # for Python 3.9:
                # msg='pos %s raised exception: %s' %
                #     (repr(pos), repr(e))
                v.cancel()

            # populate real exception up
            raise
        else:
            # handle next
            nextpos = next(self.positer)

            if nextpos in self.waiting:
                #_debprint('np:', repr(self), nextpos,
                #    repr(self.waiting[nextpos]))
                self.waiting[nextpos].set_result(None)
            else:
                self.waiting[nextpos] = self.token

class TestSequencing(unittest.IsolatedAsyncioTestCase):
    @timeout(2)
    async def test_seq_alreadywaiting(self):
        waitseq = AsyncSequence()
        seq = AsyncSequence()

        async def fun1():
            async with waitseq.sync(1):
                pass

        async def fun2():
            async with seq.sync(1):
                async with waitseq.sync(1): # pragma: no cover
                    pass

        task1 = asyncio.create_task(fun1())
        task2 = asyncio.create_task(fun2())

        # spin things to make sure things advance
        await asyncio.sleep(0)

        async with seq.sync(0):
            pass

        with self.assertRaises(RuntimeError):
            await task2

        async with waitseq.sync(0):
            pass

        await task1

    @timeout(2)
    async def test_seqexc(self):
        seq = AsyncSequence()
        excseq = AsyncSequence()

        async def excfun1():
            async with seq.sync(1):
                pass

            async with excseq.sync(0):
                raise ValueError('foo')

        # that a block that enters first, but runs after
        # raises an exception
        async def excfun2():
            async with seq.sync(0):
                pass

            async with excseq.sync(1): # pragma: no cover
                pass

        # that a block that enters after, raises an
        # exception
        async def excfun3():
            async with seq.sync(2):
                pass

            async with excseq.sync(2): # pragma: no cover
                pass

        task1 = asyncio.create_task(excfun1())
        task2 = asyncio.create_task(excfun2())
        task3 = asyncio.create_task(excfun3())

        with self.assertRaises(ValueError):
            await task1

        with self.assertRaises(asyncio.CancelledError):
            await task2

        with self.assertRaises(asyncio.CancelledError):
            await task3

    @timeout(2)
    async def test_seq(self):
        # test that a seq object when created
        seq = AsyncSequence(lambda: itertools.count(1))

        col = []

        async def fun1():
            async with seq.sync(1):
                col.append(1)

            async with seq.sync(2):
                col.append(2)

            async with seq.sync(4):
                col.append(4)

        async def fun2():
            async with seq.sync(3):
                col.append(3)

            async with seq.sync(6):
                col.append(6)

        async def fun3():
            async with seq.sync(5):
                col.append(5)

        # and various functions are run
        task1 = asyncio.create_task(fun1())
        task2 = asyncio.create_task(fun2())
        task3 = asyncio.create_task(fun3())

        # and the functions complete
        await task3
        await task2
        await task1

        # that the order they ran in was correct
        self.assertEqual(col, list(range(1, 7)))

class TestX25519(unittest.TestCase):
    def test_basic(self):
        aprivkey = lora_comms.x25519_genkey()
        apubkey = lora_comms.x25519_base(aprivkey, 1)

        bprivkey = lora_comms.x25519_genkey()
        bpubkey = lora_comms.x25519_base(bprivkey, 1)

        self.assertNotEqual(aprivkey, bprivkey)
        self.assertNotEqual(apubkey, bpubkey)

        ra = lora_comms.x25519_wrap(apubkey, aprivkey, bpubkey, 1)

        rb = lora_comms.x25519_wrap(bpubkey, bprivkey, apubkey, 1)

        self.assertEqual(ra, rb)

class TestLORANode(unittest.IsolatedAsyncioTestCase):
    shared_domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

    def test_initparams(self):
        # make sure no keys fails
        with self.assertRaises(RuntimeError):
            l = LORANode(None)

    @timeout(2)
    async def test_lora_shared(self):
        _self = self
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def sendgettest(self, msg):
                '''Send the message, but make sure that if a
                bad message is sent afterward, that it replies
                w/ the same previous message.
                '''

                await self.put(msg)
                resp = await self.get()

                await self.put(b'bogusmsg' * 5)

                resp2 = await self.get()

                _self.assertEqual(resp, resp2)

                return resp

            async def runner(self):
                l = Strobe(TestLORANode.shared_domain, F=KeccakF(800))

                l.key(shared_key)

                # start handshake
                r = await self.get()

                pkt = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert pkt.endswith(b'reqreset')

                # make sure junk gets ignored
                await self.put(b'sdlfkj')

                # and that the packet remains the same
                _self.assertEqual(r, await self.get())

                # and a couple more times
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())
                await self.put(b'0' * 32)
                _self.assertEqual(r, await self.get())

                # send the response
                await self.put(l.send_enc(os.urandom(16)) +
                    l.send_mac(8))

                # require no more back tracking at this point
                l.ratchet()

                # get the confirmation message
                r = await self.get()

                # test the resend capabilities
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())

                # decode confirmation message
                c = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                # assert that we got it
                _self.assertEqual(c, b'confirm')

                # send confirmed reply
                r = await self.sendgettest(l.send_enc(
                    b'confirmed') + l.send_mac(8))

                # test and decode remaining command messages
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:],
                    byteorder='little') == 30

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5],
                    byteorder='little') == 1
                assert int.from_bytes(cmd[5:],
                    byteorder='little') == 50

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_TERMINATE

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

        tsd = TestSD()
        l = LORANode(tsd, shared=shared_key)

        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())
        #_debprint('done')

    @timeout(2)
    async def test_ccode_badmsgs(self):
        # Test to make sure that various bad messages in the
        # handshake process are rejected even if the attacker
        # has the correct key.  This just keeps the protocol
        # tight allowing for variations in the future.

        # seed the RNG
        prngseed = b'abc123'
        from ctypes import c_uint8
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        cb = lora_comms.process_msgfunc_t(lambda msg, outbuf: None)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture, only use it to init crypto state
        tsd = SyncDatagram()
        l = LORANode(tsd, shared=shared_key)

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect init message
        msg = os.urandom(16) + b'othre'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # that various short messages don't cause problems
        for i in range(10):
            out = lora_comms.comms_process_wrap(commstate, b'0' * i)

            self.assertFalse(out)

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect init message
        msg = os.urandom(16) + b' eqreset'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # compose the correct init message
        msg = os.urandom(16) + b'reqreset'
        msg = l.st.send_enc(msg) + l.st.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        l.st.recv_enc(out[:-l.MAC_LEN])
        l.st.recv_mac(out[-l.MAC_LEN:])

        l.st.ratchet()

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect confirmed message
        msg = b'onfirm'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect confirmed message
        msg = b' onfirm'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

    @timeout(2)
    async def test_ccode(self):
        _self = self
        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
                    # get message
                    inmsg = await self.get()

                    # process the test message
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen, len(out))

                    # save what was originally replied
                    origmsg = out

                    # pretend that the reply didn't make it
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure that the reply matches
                    # the previous
                    _self.assertEqual(origmsg, out)

                    # pass the reply back
                    await self.put(out)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.ping()

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
        #_debprint('done')

    @timeout(2)
    async def test_ccode_newsession(self):
        '''This test is to make sure that if an existing session
        is running, that a new session can be established, and that
        when it does, the old session becomes inactive.
        '''

        _self = self
        from ctypes import c_uint8

        seq = AsyncSequence()

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_WAITFOR, [ 70 ]),
            (CMD_WAITFOR, [ 40 ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found: %d' % cmd)

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class FlipMsg(object):
            async def flipmsg(self):
                # get message
                inmsg = await self.get()

                # process the test message
                out = lora_comms.comms_process_wrap(
                    commstate, inmsg)

                # pass the reply back
                await self.put(out)

        # this class always passes messages, this is
        # used for the first session.
        class CCodeSD1(MockSyncDatagram, FlipMsg):
            async def runner(self):
                for i in range(3):
                    await self.flipmsg()

                async with seq.sync(0):
                    # create bogus message
                    inmsg = b'0'*24

                    # process the bogus message
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure there was not a response
                    _self.assertFalse(out)

                await self.flipmsg()

        # this one is special in that it will pause after the first
        # message to ensure that the previous session will continue
        # to work, AND that if a new "new" session comes along, it
        # will override the previous new session that hasn't been
        # confirmed yet.
        class CCodeSD2(MockSyncDatagram, FlipMsg):
            async def runner(self):
                # pass one message from the new session
                async with seq.sync(1):
                    # There might be a missing case
                    # handled for when the confirmed
                    # message is generated, but lost.
                    await self.flipmsg()

                    # and the old session is still active
                    await l.waitfor(70)

                async with seq.sync(2):
                    for i in range(3):
                        await self.flipmsg()

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD1()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        # Ensure that a new one can take over
        tsd2 = CCodeSD2()
        l2 = LORANode(tsd2, shared=shared_key)

        # Send various messages
        await l2.start()

        await l2.waitfor(40)

        await l2.terminate()

        await tsd.drain()
        await tsd2.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())
        self.assertTrue(tsd2.sendq.empty())
        self.assertTrue(tsd2.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)

class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase):
    # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
    maddr = ('224.0.0.198', 48542)

    @timeout(2)
    async def test_multisyncdgram(self):
        # Test the implementation of the multicast version of
        # SyncDatagram

        _self = self
        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # create the object we are testing
        msd = MulticastSyncDatagram(self.maddr)

        seq = AsyncSequence()

        async def clienttask():
            mr = await multicast.create_multicast_receiver(
                self.maddr)
            mt = await multicast.create_multicast_transmitter(
                self.maddr)

            try:
                # make sure the above threads are running
                await seq.simpsync(0)

                while True:
                    pkt = await mr.recv()
                    msg = pkt[0]

                    out = lora_comms.comms_process_wrap(
                        commstate, msg)

                    if out:
                        await mt.send(out)
            finally:
                mr.close()
                mt.close()

        task = asyncio.create_task(clienttask())

        # start it
        await msd.start()

        # pass it to a node
        l = LORANode(msd, shared=shared_key)

        await seq.simpsync(1)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.ping()

        await l.terminate()

        # shut things down
        l = None
        msd.close()

        task.cancel()

        with self.assertRaises(asyncio.CancelledError):
            await task