Browse Source

add support for starting a new session...

This also adds a check to make sure that the allocated structure matches
the C code so that things won't break.

This breaks out the session state into it's own object... allowing a
common function to handle the state machine...

This also adds a new powerful testing tool.  It's a Synchronization
primitive that will ensure blocks of code run in the expected order,
and only run when the previous block has fully run...  This helps
ensure ordering between many tasks, to cause race conditions that
would otherwise be hard to cause..
irr_shared
John-Mark Gurney 3 years ago
parent
commit
4024d1d5e4
4 changed files with 517 additions and 46 deletions
  1. +101
    -34
      comms.c
  2. +27
    -2
      comms.h
  3. +366
    -5
      lora.py
  4. +23
    -5
      lora_comms.py

+ 101
- 34
comms.c View File

@@ -44,6 +44,13 @@ comms_pktbuf_equal(struct pktbuf a, struct pktbuf b)
return memcmp(a.pkt, b.pkt, a.pktlen) == 0;
}

size_t
_comms_state_size()
{

return sizeof(struct comms_state);
}

size_t
_strobe_state_size()
{
@@ -56,7 +63,6 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared)
{

*cs = (struct comms_state){
.cs_comm_state = COMMS_WAIT_REQUEST,
.cs_procmsg = pmf,
};

@@ -66,45 +72,44 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared)
strobe_key(&cs->cs_start, SYM_KEY, shared->pkt, shared->pktlen);

/* copy starting state over to initial state */
cs->cs_state = cs->cs_start;
cs->cs_active = (struct comms_session){
.cs_crypto = cs->cs_start,
.cs_state = COMMS_WAIT_REQUEST,
};
cs->cs_pending = cs->cs_active;
}

#define CONFIRMED_STR_BASE "confirmed"
#define CONFIRMED_STR ((const uint8_t *)CONFIRMED_STR_BASE)
#define CONFIRMED_STR_LEN (sizeof(CONFIRMED_STR_BASE) - 1)

/*
* encrypted data to be processed is passed in via pbin.
*
* The pktbuf pointed to by pbout contains the buffer that a [encrypted]
* response will be written to. The length needs to be updated, where 0
* means no reply.
*/
void
comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)
static void
_comms_process_session(struct comms_state *cs, struct comms_session *sess, struct pktbuf pbin, struct pktbuf *pbout)
{
strobe_s tmp;
uint8_t buf[64] = {};
struct pktbuf pbmsg, pbrep;
ssize_t cnt, ret, msglen;

/* if the current msg matches the previous */
if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) {
/* send the previous response */
pbout->pktlen = cs->cs_prevmsgresp.pktlen;
memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen);
return;
}
/* save the state incase the message is bad */
tmp = sess->cs_crypto;

strobe_attach_buffer(&cs->cs_state, pbin.pkt, pbin.pktlen);
strobe_attach_buffer(&sess->cs_crypto, pbin.pkt, pbin.pktlen);

cnt = strobe_get(&cs->cs_state, APP_CIPHERTEXT, buf, pbin.pktlen -
cnt = strobe_get(&sess->cs_crypto, APP_CIPHERTEXT, buf, pbin.pktlen -
MAC_LEN);
msglen = cnt;

cnt = strobe_get(&cs->cs_state, MAC, pbin.pkt +
cnt = strobe_get(&sess->cs_crypto, MAC, pbin.pkt +
(pbin.pktlen - MAC_LEN), MAC_LEN);

/* XXX - cnt != MAC_LEN test case */
/* MAC check failed */
if (cnt == -1) {
/* restore the previous state */
sess->cs_crypto = tmp;
pbout->pktlen = 0;
return;
}

/*
* if we have arrived here, MAC has been verified, and buf now
@@ -112,29 +117,29 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)
*/

/* attach the buffer for output */
strobe_attach_buffer(&cs->cs_state, pbout->pkt, pbout->pktlen);
strobe_attach_buffer(&sess->cs_crypto, pbout->pkt, pbout->pktlen);

ret = 0;
switch (cs->cs_comm_state) {
switch (sess->cs_state) {
case COMMS_WAIT_REQUEST:
/* XXX - reqreset check */

bare_strobe_randomize(buf, CHALLENGE_LEN);
ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, buf,
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, buf,
CHALLENGE_LEN);
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN);
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN);

strobe_operate(&cs->cs_state, RATCHET, NULL, 32);
strobe_operate(&sess->cs_crypto, RATCHET, NULL, 32);

cs->cs_comm_state = COMMS_WAIT_CONFIRM;
sess->cs_state = COMMS_WAIT_CONFIRM;
break;

case COMMS_WAIT_CONFIRM:
/* XXX - confirm check */
ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, CONFIRMED_STR,
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, CONFIRMED_STR,
CONFIRMED_STR_LEN);
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN);
cs->cs_comm_state = COMMS_PROCESS_MSGS;
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN);
sess->cs_state = COMMS_PROCESS_MSGS;
break;

case COMMS_PROCESS_MSGS: {
@@ -150,9 +155,9 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)

cs->cs_procmsg(pbmsg, &pbrep);

ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, repbuf,
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, repbuf,
pbrep.pktlen);
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN);
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN);

break;
}
@@ -161,8 +166,36 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)
/* set the output buffer length */
pbout->pktlen = ret;

if (ret != 0) {
}

/*
* encrypted data to be processed is passed in via pbin.
*
* The pktbuf pointed to by pbout contains the buffer that a [encrypted]
* response will be written to. The length needs to be updated, where 0
* means no reply.
*/
void
comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)
{
struct pktbuf pbouttmp;

/* if the current msg matches the previous */
if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) {
/* send the previous response */
pbout->pktlen = cs->cs_prevmsgresp.pktlen;
memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen);
return;
}

/* try to use the active session */
pbouttmp = *pbout;
_comms_process_session(cs, &cs->cs_active, pbin, &pbouttmp);

if (pbouttmp.pktlen != 0) {
retmsg:
/* we accepted a new message store it */
*pbout = pbouttmp;

/* store the req */
cs->cs_prevmsg.pkt = cs->cs_prevmsgbuf;
@@ -173,5 +206,39 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout)
cs->cs_prevmsgresp.pkt = cs->cs_prevmsgrespbuf;
cs->cs_prevmsgresp.pktlen = pbout->pktlen;
memcpy(cs->cs_prevmsgresp.pkt, pbout->pkt, pbout->pktlen);
} else {
/* active session didn't work, try cs_pending */

pbouttmp = *pbout;
_comms_process_session(cs, &cs->cs_pending, pbin, &pbouttmp);

if (cs->cs_pending.cs_state == COMMS_PROCESS_MSGS) {
/* new active state */
cs->cs_active = cs->cs_pending;
cs->cs_pending = (struct comms_session){
.cs_crypto = cs->cs_start,
.cs_state = COMMS_WAIT_REQUEST,
};
goto retmsg;
}

/* pending session didn't work, maybe new */
struct comms_session tmpsess;

tmpsess = (struct comms_session){
.cs_crypto = cs->cs_start,
.cs_state = COMMS_WAIT_REQUEST,
};

pbouttmp = *pbout;
_comms_process_session(cs, &tmpsess, pbin, &pbouttmp);
if (tmpsess.cs_state == COMMS_WAIT_CONFIRM) {
/* new request for session */
cs->cs_pending = tmpsess;
*pbout = pbouttmp;
} else {
/* no packet to reply with */
pbout->pktlen = 0;
}
}
}

+ 27
- 2
comms.h View File

@@ -45,9 +45,33 @@ enum comm_state {
COMMS_PROCESS_MSGS,
};

struct comms_session {
strobe_s cs_crypto;
enum comm_state cs_state;
};

/*
* Each message will be passed to each state.
*
* cs_active can be in any state.
* cs_pending can only be in a _WAIT_* state.
*
* When cs_pending advances to _PROCESS_MSGS, it will
* replace cs_active, and cs_pending w/ be copied from cache
* and set to _WAIT_REQUEST.
*
* If any message was not processed by the first to, a new session
* will be attempted w/ the _start crypto state, and if it progresses
* to _WAIT_CONFIG, it will replace cs_pending.
*
* We don't have to save the reply from a new session, because if the
* reply gets lost, the initiator will send the request again and we'll
* restart the session.
*/
struct comms_state {
strobe_s cs_state;
enum comm_state cs_comm_state;
struct comms_session cs_active; /* current active session */
struct comms_session cs_pending; /* current pending session */

strobe_s cs_start; /* special starting state cache */

process_msgfunc_t cs_procmsg;
@@ -60,6 +84,7 @@ struct comms_state {
};

size_t _strobe_state_size();
size_t _comms_state_size();

void comms_init(struct comms_state *, process_msgfunc_t, struct pktbuf *);
void comms_process(struct comms_state *, struct pktbuf, struct pktbuf *);

+ 366
- 5
lora.py View File

@@ -23,7 +23,9 @@
#

import asyncio
import contextlib
import functools
import itertools
import os
import unittest

@@ -42,6 +44,7 @@ CMD_TERMINATE = 1 # no args: terminate the sesssion, reply confirms
# The follow commands are queue up, but will be acknoledged when queued
CMD_WAITFOR = 2 # arg: (length): waits for length seconds
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds
CMD_PING = 4 # arg: (): a no op command

class LORANode(object):
'''Implement a LORANode initiator.'''
@@ -62,7 +65,8 @@ class LORANode(object):
pkt = await self.sendrecvvalid(b'confirm')

if pkt != b'confirmed':
raise RuntimeError
raise RuntimeError('got invalid response: %s' %
repr(pkt))

async def sendrecvvalid(self, msg):
msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)
@@ -73,6 +77,10 @@ class LORANode(object):
resp = await self.sd.sendtillrecv(msg, 1)
#_debprint('got:', resp)

# skip empty messages
if len(resp) == 0:
continue

try:
decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
self.st.recv_mac(resp[-self.MAC_LEN:])
@@ -110,6 +118,9 @@ class LORANode(object):
async def runfor(self, chan, length):
return await self._sendcmd(CMD_RUNFOR, chan, length)

async def ping(self):
return await self._sendcmd(CMD_PING)

async def terminate(self):
return await self._sendcmd(CMD_TERMINATE)

@@ -127,7 +138,6 @@ class SyncDatagram(object):
raise NotImplementedError

async def send(self, data): #pragma: no cover
'''Send a datagram.'''

raise NotImplementedError

@@ -153,7 +163,6 @@ class MockSyncDatagram(SyncDatagram):
def __init__(self):
self.sendq = asyncio.Queue()
self.recvq = asyncio.Queue()
self.task = None
self.task = asyncio.create_task(self.runner())

self.get = self.sendq.get
@@ -225,6 +234,211 @@ def _debprint(*args): # pragma: no cover
*args)
sys.stdout.flush()

class AsyncSequence(object):
'''
Object used for sequencing async functions. To use, use the
asynchronous context manager created by the sync method. For
example:
seq = AsyncSequence()
async func1():
async with seq.sync(1):
second_fun()

async func2():
async with seq.sync(0):
first_fun()

This will make sure that function first_fun is run before running
the function second_fun. If a previous block raises an Exception,
it will be passed up, and all remaining blocks (and future ones)
will raise a CancelledError to help ensure that any tasks are
properly cleaned up.
'''

def __init__(self, positerfactory=lambda: itertools.count()):
'''The argument positerfactory, is a factory that will
create an iterator that will be used for the values that
are passed to the sync method.'''

self.positer = positerfactory()
self.token = object()
self.die = False
self.waiting = {
next(self.positer): self.token
}

@contextlib.asynccontextmanager
async def sync(self, pos):
'''An async context manager that will be run when it's
turn arrives. It will only run when all the previous
items in the iterator has been successfully run.'''

if self.die:
raise asyncio.CancelledError('seq cancelled')

if pos in self.waiting:
if self.waiting[pos] is not self.token:
raise RuntimeError('pos already waiting!')
else:
fut = asyncio.Future()
self.waiting[pos] = fut
await fut

# our time to shine!
del self.waiting[pos]

try:
yield None
except Exception as e:
# if we got an exception, things went pear shaped,
# shut everything down, and any future calls.

#_debprint('dieing...', repr(e))
self.die = True

# cancel existing blocks
while self.waiting:
k, v = self.waiting.popitem()
#_debprint('canceling: %s' % repr(k))
if v is self.token:
continue

# for Python 3.9:
# msg='pos %s raised exception: %s' %
# (repr(pos), repr(e))
v.cancel()

# populate real exception up
raise
else:
# handle next
nextpos = next(self.positer)

if nextpos in self.waiting:
#_debprint('np:', repr(self), nextpos,
# repr(self.waiting[nextpos]))
self.waiting[nextpos].set_result(None)
else:
self.waiting[nextpos] = self.token

class TestSequencing(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_seq_alreadywaiting(self):
waitseq = AsyncSequence()

seq = AsyncSequence()

async def fun1():
async with waitseq.sync(1):
pass

async def fun2():
async with seq.sync(1):
async with waitseq.sync(1): # pragma: no cover
pass

task1 = asyncio.create_task(fun1())
task2 = asyncio.create_task(fun2())

# spin things to make sure things advance
await asyncio.sleep(0)

async with seq.sync(0):
pass

with self.assertRaises(RuntimeError):
await task2

async with waitseq.sync(0):
pass

await task1

@timeout(2)
async def test_seqexc(self):
seq = AsyncSequence()

excseq = AsyncSequence()

async def excfun1():
async with seq.sync(1):
pass

async with excseq.sync(0):
raise ValueError('foo')

# that a block that enters first, but runs after
# raises an exception
async def excfun2():
async with seq.sync(0):
pass

async with excseq.sync(1): # pragma: no cover
pass

# that a block that enters after, raises an
# exception
async def excfun3():
async with seq.sync(2):
pass

async with excseq.sync(2): # pragma: no cover
pass

task1 = asyncio.create_task(excfun1())
task2 = asyncio.create_task(excfun2())
task3 = asyncio.create_task(excfun3())

with self.assertRaises(ValueError):
await task1

with self.assertRaises(asyncio.CancelledError):
await task2

with self.assertRaises(asyncio.CancelledError):
await task3

@timeout(2)
async def test_seq(self):
# test that a seq object when created
seq = AsyncSequence(lambda: itertools.count(1))

col = []

async def fun1():
async with seq.sync(1):
col.append(1)

async with seq.sync(2):
col.append(2)

async with seq.sync(4):
col.append(4)

async def fun2():
async with seq.sync(3):
col.append(3)

async with seq.sync(6):
col.append(6)

async def fun3():
async with seq.sync(5):
col.append(5)

# and various functions are run
task1 = asyncio.create_task(fun1())
task2 = asyncio.create_task(fun2())
task3 = asyncio.create_task(fun3())

# and the functions complete
await task3
await task2
await task1

# that the order they ran in was correct
self.assertEqual(col, list(range(1, 7)))

class TestLORANode(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_lora(self):
@@ -365,11 +579,11 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
exptmsgs = [
(CMD_WAITFOR, [ 30 ]),
(CMD_RUNFOR, [ 1, 50 ]),
(CMD_PING, [ ]),
(CMD_TERMINATE, [ ]),
]
def procmsg(msg, outbuf):
msgbuf = msg._from()
#print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf))
cmd = msgbuf[0]
args = [ int.from_bytes(msgbuf[x:x + 4],
byteorder='little') for x in range(1, len(msgbuf),
@@ -387,7 +601,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):

class CCodeSD(MockSyncDatagram):
async def runner(self):
for expectlen in [ 24, 17, 9, 9, 9 ]:
for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
# get message
gb = await self.get()
r = make_pktbuf(gb)
@@ -438,6 +652,8 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):

await l.runfor(1, 50)

await l.ping()

await l.terminate()

await tsd.drain()
@@ -450,3 +666,148 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
# processed.
self.assertFalse(exptmsgs)
#_debprint('done')

@timeout(2)
async def test_ccode_newsession(self):
'''This test is to make sure that if an existing session
is running, that a new session can be established, and that
when it does, the old session becomes inactive.
'''

_self = self
from ctypes import pointer, sizeof, c_uint8

seq = AsyncSequence()

# seed the RNG
prngseed = b'abc123'
lora_comms.strobe_seed_prng((c_uint8 *
len(prngseed))(*prngseed), len(prngseed))

# Create the state for testing
commstate = lora_comms.CommsState()

# These are the expected messages and their arguments
exptmsgs = [
(CMD_WAITFOR, [ 30 ]),
(CMD_WAITFOR, [ 70 ]),
(CMD_WAITFOR, [ 40 ]),
(CMD_TERMINATE, [ ]),
]
def procmsg(msg, outbuf):
msgbuf = msg._from()
cmd = msgbuf[0]
args = [ int.from_bytes(msgbuf[x:x + 4],
byteorder='little') for x in range(1, len(msgbuf),
4) ]

if exptmsgs[0] == (cmd, args):
exptmsgs.pop(0)
outbuf[0].pkt[0] = cmd
outbuf[0].pktlen = 1
else: #pragma: no cover
raise RuntimeError('cmd not found: %d' % cmd)

# wrap the callback function
cb = lora_comms.process_msgfunc_t(procmsg)

class FlipMsg(object):
async def flipmsg(self):
# get message
gb = await self.get()
r = make_pktbuf(gb)

outbytes = bytearray(64)
outbuf = make_pktbuf(outbytes)

# process the test message
lora_comms.comms_process(commstate, r,
outbuf)

# pass the reply back
pkt = outbytes[:outbuf.pktlen]
await self.put(pkt)

# this class always passes messages, this is
# used for the first session.
class CCodeSD1(MockSyncDatagram, FlipMsg):
async def runner(self):
for i in range(3):
await self.flipmsg()

async with seq.sync(0):
# create bogus message
r = make_pktbuf(b'0'*24)

outbytes = bytearray(64)
outbuf = make_pktbuf(outbytes)

# process the bogus message
lora_comms.comms_process(commstate, r,
outbuf)

# make sure there was not a response
_self.assertEqual(outbuf.pktlen, 0)

await self.flipmsg()

# this one is special in that it will pause after the first
# message to ensure that the previous session will continue
# to work, AND that if a new "new" session comes along, it
# will override the previous new session that hasn't been
# confirmed yet.
class CCodeSD2(MockSyncDatagram, FlipMsg):
async def runner(self):
# pass one message from the new session
async with seq.sync(1):
# There might be a missing case
# handled for when the confirmed
# message is generated, but lost.
await self.flipmsg()

# and the old session is still active
await l.waitfor(70)

async with seq.sync(2):
for i in range(3):
await self.flipmsg()

# Generate shared key
shared_key = os.urandom(32)

# Initialize everything
lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

# Create test fixture
tsd = CCodeSD1()
l = LORANode(tsd, shared=shared_key)

# Send various messages
await l.start()

await l.waitfor(30)

# Ensure that a new one can take over
tsd2 = CCodeSD2()

l2 = LORANode(tsd2, shared=shared_key)

# Send various messages
await l2.start()

await l2.waitfor(40)

await l2.terminate()

await tsd.drain()
await tsd2.drain()

# Make sure all messages have been processed
self.assertTrue(tsd.sendq.empty())
self.assertTrue(tsd.recvq.empty())
self.assertTrue(tsd2.sendq.empty())
self.assertTrue(tsd2.recvq.empty())

# Make sure all the expected messages have been
# processed.
self.assertFalse(exptmsgs)

+ 23
- 5
lora_comms.py View File

@@ -22,10 +22,15 @@
# SUCH DAMAGE.
#

from ctypes import Structure, POINTER, CFUNCTYPE, pointer
from ctypes import Structure, POINTER, CFUNCTYPE, pointer, sizeof
from ctypes import c_uint8, c_uint16, c_ssize_t, c_size_t, c_uint64, c_int
from ctypes import CDLL

class StructureRepr(object):
def __repr__(self): #pragma: no cover
return '%s(%s)' % (self.__class__.__name__, ', '.join('%s=%s' %
(k, getattr(self, k)) for k, v in self._fields_))

class PktBuf(Structure):
_fields_ = [
('pkt', POINTER(c_uint8)),
@@ -63,11 +68,17 @@ _lib._strobe_state_size.restype = c_size_t
_lib._strobe_state_size.argtypes = ()
_strobe_state_u64_cnt = (_lib._strobe_state_size() + 7) // 8

class CommsState(Structure):
class CommsSession(Structure,StructureRepr):
_fields_ = [
('cs_crypto', c_uint64 * _strobe_state_u64_cnt),
('cs_state', c_int),
]

class CommsState(Structure,StructureRepr):
_fields_ = [
# The alignment of these may be off
('cs_state', c_uint64 * _strobe_state_u64_cnt),
('cs_comm_state', c_int),
('cs_active', CommsSession),
('cs_pending', CommsSession),
('cs_start', c_uint64 * _strobe_state_u64_cnt),
('cs_procmsg', process_msgfunc_t),

@@ -78,8 +89,15 @@ class CommsState(Structure):
('cs_prevmsgrespbuf', c_uint8 * 64),
]

_lib._comms_state_size.restype = c_size_t
_lib._comms_state_size.argtypes = ()

if _lib._comms_state_size() != sizeof(CommsState): # pragma: no cover
raise RuntimeError('CommsState structure size mismatch!')

for func, ret, args in [
('comms_init', None, (POINTER(CommsState), process_msgfunc_t, POINTER(PktBuf))),
('comms_init', None, (POINTER(CommsState), process_msgfunc_t,
POINTER(PktBuf))),
('comms_process', None, (POINTER(CommsState), PktBuf, POINTER(PktBuf))),
('strobe_seed_prng', None, (POINTER(c_uint8), c_ssize_t)),
]:


Loading…
Cancel
Save