diff --git a/comms.c b/comms.c index 21092bb..c38bc29 100644 --- a/comms.c +++ b/comms.c @@ -44,6 +44,13 @@ comms_pktbuf_equal(struct pktbuf a, struct pktbuf b) return memcmp(a.pkt, b.pkt, a.pktlen) == 0; } +size_t +_comms_state_size() +{ + + return sizeof(struct comms_state); +} + size_t _strobe_state_size() { @@ -56,7 +63,6 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared) { *cs = (struct comms_state){ - .cs_comm_state = COMMS_WAIT_REQUEST, .cs_procmsg = pmf, }; @@ -66,45 +72,44 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared) strobe_key(&cs->cs_start, SYM_KEY, shared->pkt, shared->pktlen); /* copy starting state over to initial state */ - cs->cs_state = cs->cs_start; + cs->cs_active = (struct comms_session){ + .cs_crypto = cs->cs_start, + .cs_state = COMMS_WAIT_REQUEST, + }; + cs->cs_pending = cs->cs_active; } #define CONFIRMED_STR_BASE "confirmed" #define CONFIRMED_STR ((const uint8_t *)CONFIRMED_STR_BASE) #define CONFIRMED_STR_LEN (sizeof(CONFIRMED_STR_BASE) - 1) -/* - * encrypted data to be processed is passed in via pbin. - * - * The pktbuf pointed to by pbout contains the buffer that a [encrypted] - * response will be written to. The length needs to be updated, where 0 - * means no reply. - */ -void -comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) +static void +_comms_process_session(struct comms_state *cs, struct comms_session *sess, struct pktbuf pbin, struct pktbuf *pbout) { + strobe_s tmp; uint8_t buf[64] = {}; struct pktbuf pbmsg, pbrep; ssize_t cnt, ret, msglen; - /* if the current msg matches the previous */ - if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) { - /* send the previous response */ - pbout->pktlen = cs->cs_prevmsgresp.pktlen; - memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen); - return; - } + /* save the state incase the message is bad */ + tmp = sess->cs_crypto; - strobe_attach_buffer(&cs->cs_state, pbin.pkt, pbin.pktlen); + strobe_attach_buffer(&sess->cs_crypto, pbin.pkt, pbin.pktlen); - cnt = strobe_get(&cs->cs_state, APP_CIPHERTEXT, buf, pbin.pktlen - + cnt = strobe_get(&sess->cs_crypto, APP_CIPHERTEXT, buf, pbin.pktlen - MAC_LEN); msglen = cnt; - cnt = strobe_get(&cs->cs_state, MAC, pbin.pkt + + cnt = strobe_get(&sess->cs_crypto, MAC, pbin.pkt + (pbin.pktlen - MAC_LEN), MAC_LEN); - /* XXX - cnt != MAC_LEN test case */ + /* MAC check failed */ + if (cnt == -1) { + /* restore the previous state */ + sess->cs_crypto = tmp; + pbout->pktlen = 0; + return; + } /* * if we have arrived here, MAC has been verified, and buf now @@ -112,29 +117,29 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) */ /* attach the buffer for output */ - strobe_attach_buffer(&cs->cs_state, pbout->pkt, pbout->pktlen); + strobe_attach_buffer(&sess->cs_crypto, pbout->pkt, pbout->pktlen); ret = 0; - switch (cs->cs_comm_state) { + switch (sess->cs_state) { case COMMS_WAIT_REQUEST: /* XXX - reqreset check */ bare_strobe_randomize(buf, CHALLENGE_LEN); - ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, buf, + ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, buf, CHALLENGE_LEN); - ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); + ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); - strobe_operate(&cs->cs_state, RATCHET, NULL, 32); + strobe_operate(&sess->cs_crypto, RATCHET, NULL, 32); - cs->cs_comm_state = COMMS_WAIT_CONFIRM; + sess->cs_state = COMMS_WAIT_CONFIRM; break; case COMMS_WAIT_CONFIRM: /* XXX - confirm check */ - ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, CONFIRMED_STR, + ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, CONFIRMED_STR, CONFIRMED_STR_LEN); - ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); - cs->cs_comm_state = COMMS_PROCESS_MSGS; + ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); + sess->cs_state = COMMS_PROCESS_MSGS; break; case COMMS_PROCESS_MSGS: { @@ -150,9 +155,9 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) cs->cs_procmsg(pbmsg, &pbrep); - ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, repbuf, + ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, repbuf, pbrep.pktlen); - ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); + ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); break; } @@ -161,8 +166,36 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) /* set the output buffer length */ pbout->pktlen = ret; - if (ret != 0) { +} + +/* + * encrypted data to be processed is passed in via pbin. + * + * The pktbuf pointed to by pbout contains the buffer that a [encrypted] + * response will be written to. The length needs to be updated, where 0 + * means no reply. + */ +void +comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) +{ + struct pktbuf pbouttmp; + + /* if the current msg matches the previous */ + if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) { + /* send the previous response */ + pbout->pktlen = cs->cs_prevmsgresp.pktlen; + memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen); + return; + } + + /* try to use the active session */ + pbouttmp = *pbout; + _comms_process_session(cs, &cs->cs_active, pbin, &pbouttmp); + + if (pbouttmp.pktlen != 0) { +retmsg: /* we accepted a new message store it */ + *pbout = pbouttmp; /* store the req */ cs->cs_prevmsg.pkt = cs->cs_prevmsgbuf; @@ -173,5 +206,39 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) cs->cs_prevmsgresp.pkt = cs->cs_prevmsgrespbuf; cs->cs_prevmsgresp.pktlen = pbout->pktlen; memcpy(cs->cs_prevmsgresp.pkt, pbout->pkt, pbout->pktlen); + } else { + /* active session didn't work, try cs_pending */ + + pbouttmp = *pbout; + _comms_process_session(cs, &cs->cs_pending, pbin, &pbouttmp); + + if (cs->cs_pending.cs_state == COMMS_PROCESS_MSGS) { + /* new active state */ + cs->cs_active = cs->cs_pending; + cs->cs_pending = (struct comms_session){ + .cs_crypto = cs->cs_start, + .cs_state = COMMS_WAIT_REQUEST, + }; + goto retmsg; + } + + /* pending session didn't work, maybe new */ + struct comms_session tmpsess; + + tmpsess = (struct comms_session){ + .cs_crypto = cs->cs_start, + .cs_state = COMMS_WAIT_REQUEST, + }; + + pbouttmp = *pbout; + _comms_process_session(cs, &tmpsess, pbin, &pbouttmp); + if (tmpsess.cs_state == COMMS_WAIT_CONFIRM) { + /* new request for session */ + cs->cs_pending = tmpsess; + *pbout = pbouttmp; + } else { + /* no packet to reply with */ + pbout->pktlen = 0; + } } } diff --git a/comms.h b/comms.h index b1f60c4..b4e71dc 100644 --- a/comms.h +++ b/comms.h @@ -45,9 +45,33 @@ enum comm_state { COMMS_PROCESS_MSGS, }; +struct comms_session { + strobe_s cs_crypto; + enum comm_state cs_state; +}; + +/* + * Each message will be passed to each state. + * + * cs_active can be in any state. + * cs_pending can only be in a _WAIT_* state. + * + * When cs_pending advances to _PROCESS_MSGS, it will + * replace cs_active, and cs_pending w/ be copied from cache + * and set to _WAIT_REQUEST. + * + * If any message was not processed by the first to, a new session + * will be attempted w/ the _start crypto state, and if it progresses + * to _WAIT_CONFIG, it will replace cs_pending. + * + * We don't have to save the reply from a new session, because if the + * reply gets lost, the initiator will send the request again and we'll + * restart the session. + */ struct comms_state { - strobe_s cs_state; - enum comm_state cs_comm_state; + struct comms_session cs_active; /* current active session */ + struct comms_session cs_pending; /* current pending session */ + strobe_s cs_start; /* special starting state cache */ process_msgfunc_t cs_procmsg; @@ -60,6 +84,7 @@ struct comms_state { }; size_t _strobe_state_size(); +size_t _comms_state_size(); void comms_init(struct comms_state *, process_msgfunc_t, struct pktbuf *); void comms_process(struct comms_state *, struct pktbuf, struct pktbuf *); diff --git a/lora.py b/lora.py index 01da465..71b689d 100644 --- a/lora.py +++ b/lora.py @@ -23,7 +23,9 @@ # import asyncio +import contextlib import functools +import itertools import os import unittest @@ -42,6 +44,7 @@ CMD_TERMINATE = 1 # no args: terminate the sesssion, reply confirms # The follow commands are queue up, but will be acknoledged when queued CMD_WAITFOR = 2 # arg: (length): waits for length seconds CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds +CMD_PING = 4 # arg: (): a no op command class LORANode(object): '''Implement a LORANode initiator.''' @@ -62,7 +65,8 @@ class LORANode(object): pkt = await self.sendrecvvalid(b'confirm') if pkt != b'confirmed': - raise RuntimeError + raise RuntimeError('got invalid response: %s' % + repr(pkt)) async def sendrecvvalid(self, msg): msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN) @@ -73,6 +77,10 @@ class LORANode(object): resp = await self.sd.sendtillrecv(msg, 1) #_debprint('got:', resp) + # skip empty messages + if len(resp) == 0: + continue + try: decmsg = self.st.recv_enc(resp[:-self.MAC_LEN]) self.st.recv_mac(resp[-self.MAC_LEN:]) @@ -110,6 +118,9 @@ class LORANode(object): async def runfor(self, chan, length): return await self._sendcmd(CMD_RUNFOR, chan, length) + async def ping(self): + return await self._sendcmd(CMD_PING) + async def terminate(self): return await self._sendcmd(CMD_TERMINATE) @@ -127,7 +138,6 @@ class SyncDatagram(object): raise NotImplementedError async def send(self, data): #pragma: no cover - '''Send a datagram.''' raise NotImplementedError @@ -153,7 +163,6 @@ class MockSyncDatagram(SyncDatagram): def __init__(self): self.sendq = asyncio.Queue() self.recvq = asyncio.Queue() - self.task = None self.task = asyncio.create_task(self.runner()) self.get = self.sendq.get @@ -225,6 +234,211 @@ def _debprint(*args): # pragma: no cover *args) sys.stdout.flush() +class AsyncSequence(object): + ''' + Object used for sequencing async functions. To use, use the + asynchronous context manager created by the sync method. For + example: + seq = AsyncSequence() + async func1(): + async with seq.sync(1): + second_fun() + + async func2(): + async with seq.sync(0): + first_fun() + + This will make sure that function first_fun is run before running + the function second_fun. If a previous block raises an Exception, + it will be passed up, and all remaining blocks (and future ones) + will raise a CancelledError to help ensure that any tasks are + properly cleaned up. + ''' + + def __init__(self, positerfactory=lambda: itertools.count()): + '''The argument positerfactory, is a factory that will + create an iterator that will be used for the values that + are passed to the sync method.''' + + self.positer = positerfactory() + self.token = object() + self.die = False + self.waiting = { + next(self.positer): self.token + } + + @contextlib.asynccontextmanager + async def sync(self, pos): + '''An async context manager that will be run when it's + turn arrives. It will only run when all the previous + items in the iterator has been successfully run.''' + + if self.die: + raise asyncio.CancelledError('seq cancelled') + + if pos in self.waiting: + if self.waiting[pos] is not self.token: + raise RuntimeError('pos already waiting!') + else: + fut = asyncio.Future() + self.waiting[pos] = fut + await fut + + # our time to shine! + del self.waiting[pos] + + try: + yield None + except Exception as e: + # if we got an exception, things went pear shaped, + # shut everything down, and any future calls. + + #_debprint('dieing...', repr(e)) + self.die = True + + # cancel existing blocks + while self.waiting: + k, v = self.waiting.popitem() + #_debprint('canceling: %s' % repr(k)) + if v is self.token: + continue + + # for Python 3.9: + # msg='pos %s raised exception: %s' % + # (repr(pos), repr(e)) + v.cancel() + + # populate real exception up + raise + else: + # handle next + nextpos = next(self.positer) + + if nextpos in self.waiting: + #_debprint('np:', repr(self), nextpos, + # repr(self.waiting[nextpos])) + self.waiting[nextpos].set_result(None) + else: + self.waiting[nextpos] = self.token + +class TestSequencing(unittest.IsolatedAsyncioTestCase): + @timeout(2) + async def test_seq_alreadywaiting(self): + waitseq = AsyncSequence() + + seq = AsyncSequence() + + async def fun1(): + async with waitseq.sync(1): + pass + + async def fun2(): + async with seq.sync(1): + async with waitseq.sync(1): # pragma: no cover + pass + + task1 = asyncio.create_task(fun1()) + task2 = asyncio.create_task(fun2()) + + # spin things to make sure things advance + await asyncio.sleep(0) + + async with seq.sync(0): + pass + + with self.assertRaises(RuntimeError): + await task2 + + async with waitseq.sync(0): + pass + + await task1 + + @timeout(2) + async def test_seqexc(self): + seq = AsyncSequence() + + excseq = AsyncSequence() + + async def excfun1(): + async with seq.sync(1): + pass + + async with excseq.sync(0): + raise ValueError('foo') + + # that a block that enters first, but runs after + # raises an exception + async def excfun2(): + async with seq.sync(0): + pass + + async with excseq.sync(1): # pragma: no cover + pass + + # that a block that enters after, raises an + # exception + async def excfun3(): + async with seq.sync(2): + pass + + async with excseq.sync(2): # pragma: no cover + pass + + task1 = asyncio.create_task(excfun1()) + task2 = asyncio.create_task(excfun2()) + task3 = asyncio.create_task(excfun3()) + + with self.assertRaises(ValueError): + await task1 + + with self.assertRaises(asyncio.CancelledError): + await task2 + + with self.assertRaises(asyncio.CancelledError): + await task3 + + @timeout(2) + async def test_seq(self): + # test that a seq object when created + seq = AsyncSequence(lambda: itertools.count(1)) + + col = [] + + async def fun1(): + async with seq.sync(1): + col.append(1) + + async with seq.sync(2): + col.append(2) + + async with seq.sync(4): + col.append(4) + + async def fun2(): + async with seq.sync(3): + col.append(3) + + async with seq.sync(6): + col.append(6) + + async def fun3(): + async with seq.sync(5): + col.append(5) + + # and various functions are run + task1 = asyncio.create_task(fun1()) + task2 = asyncio.create_task(fun2()) + task3 = asyncio.create_task(fun3()) + + # and the functions complete + await task3 + await task2 + await task1 + + # that the order they ran in was correct + self.assertEqual(col, list(range(1, 7))) + class TestLORANode(unittest.IsolatedAsyncioTestCase): @timeout(2) async def test_lora(self): @@ -365,11 +579,11 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): exptmsgs = [ (CMD_WAITFOR, [ 30 ]), (CMD_RUNFOR, [ 1, 50 ]), + (CMD_PING, [ ]), (CMD_TERMINATE, [ ]), ] def procmsg(msg, outbuf): msgbuf = msg._from() - #print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf)) cmd = msgbuf[0] args = [ int.from_bytes(msgbuf[x:x + 4], byteorder='little') for x in range(1, len(msgbuf), @@ -387,7 +601,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): class CCodeSD(MockSyncDatagram): async def runner(self): - for expectlen in [ 24, 17, 9, 9, 9 ]: + for expectlen in [ 24, 17, 9, 9, 9, 9 ]: # get message gb = await self.get() r = make_pktbuf(gb) @@ -438,6 +652,8 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): await l.runfor(1, 50) + await l.ping() + await l.terminate() await tsd.drain() @@ -450,3 +666,148 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): # processed. self.assertFalse(exptmsgs) #_debprint('done') + + @timeout(2) + async def test_ccode_newsession(self): + '''This test is to make sure that if an existing session + is running, that a new session can be established, and that + when it does, the old session becomes inactive. + ''' + + _self = self + from ctypes import pointer, sizeof, c_uint8 + + seq = AsyncSequence() + + # seed the RNG + prngseed = b'abc123' + lora_comms.strobe_seed_prng((c_uint8 * + len(prngseed))(*prngseed), len(prngseed)) + + # Create the state for testing + commstate = lora_comms.CommsState() + + # These are the expected messages and their arguments + exptmsgs = [ + (CMD_WAITFOR, [ 30 ]), + (CMD_WAITFOR, [ 70 ]), + (CMD_WAITFOR, [ 40 ]), + (CMD_TERMINATE, [ ]), + ] + def procmsg(msg, outbuf): + msgbuf = msg._from() + cmd = msgbuf[0] + args = [ int.from_bytes(msgbuf[x:x + 4], + byteorder='little') for x in range(1, len(msgbuf), + 4) ] + + if exptmsgs[0] == (cmd, args): + exptmsgs.pop(0) + outbuf[0].pkt[0] = cmd + outbuf[0].pktlen = 1 + else: #pragma: no cover + raise RuntimeError('cmd not found: %d' % cmd) + + # wrap the callback function + cb = lora_comms.process_msgfunc_t(procmsg) + + class FlipMsg(object): + async def flipmsg(self): + # get message + gb = await self.get() + r = make_pktbuf(gb) + + outbytes = bytearray(64) + outbuf = make_pktbuf(outbytes) + + # process the test message + lora_comms.comms_process(commstate, r, + outbuf) + + # pass the reply back + pkt = outbytes[:outbuf.pktlen] + await self.put(pkt) + + # this class always passes messages, this is + # used for the first session. + class CCodeSD1(MockSyncDatagram, FlipMsg): + async def runner(self): + for i in range(3): + await self.flipmsg() + + async with seq.sync(0): + # create bogus message + r = make_pktbuf(b'0'*24) + + outbytes = bytearray(64) + outbuf = make_pktbuf(outbytes) + + # process the bogus message + lora_comms.comms_process(commstate, r, + outbuf) + + # make sure there was not a response + _self.assertEqual(outbuf.pktlen, 0) + + await self.flipmsg() + + # this one is special in that it will pause after the first + # message to ensure that the previous session will continue + # to work, AND that if a new "new" session comes along, it + # will override the previous new session that hasn't been + # confirmed yet. + class CCodeSD2(MockSyncDatagram, FlipMsg): + async def runner(self): + # pass one message from the new session + async with seq.sync(1): + # There might be a missing case + # handled for when the confirmed + # message is generated, but lost. + await self.flipmsg() + + # and the old session is still active + await l.waitfor(70) + + async with seq.sync(2): + for i in range(3): + await self.flipmsg() + + # Generate shared key + shared_key = os.urandom(32) + + # Initialize everything + lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key)) + + # Create test fixture + tsd = CCodeSD1() + l = LORANode(tsd, shared=shared_key) + + # Send various messages + await l.start() + + await l.waitfor(30) + + # Ensure that a new one can take over + tsd2 = CCodeSD2() + + l2 = LORANode(tsd2, shared=shared_key) + + # Send various messages + await l2.start() + + await l2.waitfor(40) + + await l2.terminate() + + await tsd.drain() + await tsd2.drain() + + # Make sure all messages have been processed + self.assertTrue(tsd.sendq.empty()) + self.assertTrue(tsd.recvq.empty()) + self.assertTrue(tsd2.sendq.empty()) + self.assertTrue(tsd2.recvq.empty()) + + # Make sure all the expected messages have been + # processed. + self.assertFalse(exptmsgs) diff --git a/lora_comms.py b/lora_comms.py index 7c6548f..ca1d3f7 100644 --- a/lora_comms.py +++ b/lora_comms.py @@ -22,10 +22,15 @@ # SUCH DAMAGE. # -from ctypes import Structure, POINTER, CFUNCTYPE, pointer +from ctypes import Structure, POINTER, CFUNCTYPE, pointer, sizeof from ctypes import c_uint8, c_uint16, c_ssize_t, c_size_t, c_uint64, c_int from ctypes import CDLL +class StructureRepr(object): + def __repr__(self): #pragma: no cover + return '%s(%s)' % (self.__class__.__name__, ', '.join('%s=%s' % + (k, getattr(self, k)) for k, v in self._fields_)) + class PktBuf(Structure): _fields_ = [ ('pkt', POINTER(c_uint8)), @@ -63,11 +68,17 @@ _lib._strobe_state_size.restype = c_size_t _lib._strobe_state_size.argtypes = () _strobe_state_u64_cnt = (_lib._strobe_state_size() + 7) // 8 -class CommsState(Structure): +class CommsSession(Structure,StructureRepr): + _fields_ = [ + ('cs_crypto', c_uint64 * _strobe_state_u64_cnt), + ('cs_state', c_int), + ] + +class CommsState(Structure,StructureRepr): _fields_ = [ # The alignment of these may be off - ('cs_state', c_uint64 * _strobe_state_u64_cnt), - ('cs_comm_state', c_int), + ('cs_active', CommsSession), + ('cs_pending', CommsSession), ('cs_start', c_uint64 * _strobe_state_u64_cnt), ('cs_procmsg', process_msgfunc_t), @@ -78,8 +89,15 @@ class CommsState(Structure): ('cs_prevmsgrespbuf', c_uint8 * 64), ] +_lib._comms_state_size.restype = c_size_t +_lib._comms_state_size.argtypes = () + +if _lib._comms_state_size() != sizeof(CommsState): # pragma: no cover + raise RuntimeError('CommsState structure size mismatch!') + for func, ret, args in [ - ('comms_init', None, (POINTER(CommsState), process_msgfunc_t, POINTER(PktBuf))), + ('comms_init', None, (POINTER(CommsState), process_msgfunc_t, + POINTER(PktBuf))), ('comms_process', None, (POINTER(CommsState), PktBuf, POINTER(PktBuf))), ('strobe_seed_prng', None, (POINTER(c_uint8), c_ssize_t)), ]: