Browse Source

break out some util functions, implement multicast adapter for LORANode...

irr_shared
John-Mark Gurney 3 years ago
parent
commit
08e6f93992
3 changed files with 191 additions and 28 deletions
  1. +161
    -26
      lora.py
  2. +2
    -2
      multicast.py
  3. +28
    -0
      util.py

+ 161
- 26
lora.py View File

@@ -34,6 +34,8 @@ from Strobe.Strobe import AuthenticationFailed

import lora_comms
from lora_comms import make_pktbuf
import multicast
from util import *

domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

@@ -127,9 +129,6 @@ class LORANode(object):
class SyncDatagram(object):
'''Base interface for a more simple synchronous interface.'''

def __init__(self): #pragma: no cover
pass

async def recv(self, timeout=None): #pragma: no cover
'''Receive a datagram. If timeout is not None, wait that many
seconds, and if nothing is received in that time, raise an
@@ -154,6 +153,56 @@ class SyncDatagram(object):
except TimeoutError:
pass

class MulticastSyncDatagram(SyncDatagram):
'''
An implementation of SyncDatagram that uses the provided
multicast address maddr as the source/sink of the packets.

Note that once created, the start coroutine needs to be
await'd before being passed to a LORANode so that everything
is running.
'''

# Note: sent packets will be received. A similar method to
# what was done in multicast.{to,from}_loragw could be done
# here as well, that is passing in a set of packets to not
# pass back up.

def __init__(self, maddr):
self.maddr = maddr
self._ignpkts = set()

async def start(self):
self.mr = await multicast.create_multicast_receiver(self.maddr)
self.mt = await multicast.create_multicast_transmitter(
self.maddr)

async def _recv(self):
while True:
pkt = await self.mr.recv()
pkt = pkt[0]
if pkt not in self._ignpkts:
return pkt

self._ignpkts.remove(pkt)

async def recv(self, timeout=None): #pragma: no cover
r = await asyncio.wait_for(self._recv(), timeout=timeout)

return r

async def send(self, data): #pragma: no cover
self._ignpkts.add(bytes(data))
await self.mt.send(data)

def close(self):
'''Shutdown communications.'''

self.mr.close()
self.mr = None
self.mt.close()
self.mt = None

class MockSyncDatagram(SyncDatagram):
'''A testing version of SyncDatagram. Define a method runner which
implements part of the sequence. In the function, await on either
@@ -211,29 +260,6 @@ class TestSyncData(unittest.IsolatedAsyncioTestCase):
self.assertEqual(r, b'a')
self.assertEqual(ms.sendq, [ b'foo', b'foo' ])

def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
return await asyncio.wait_for(fun(*args, **kwargs),
timeout)

return wrapper

return timeout_wrapper

def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
st = traceback.extract_stack(limit=2)[0]

sep = ''
if args:
sep = ':'

print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
*args)
sys.stdout.flush()

class AsyncSequence(object):
'''
Object used for sequencing async functions. To use, use the
@@ -267,6 +293,10 @@ class AsyncSequence(object):
next(self.positer): self.token
}

async def simpsync(self, pos):
async with self.sync(pos):
pass

@contextlib.asynccontextmanager
async def sync(self, pos):
'''An async context manager that will be run when it's
@@ -877,3 +907,108 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
# Make sure all the expected messages have been
# processed.
self.assertFalse(exptmsgs)

class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase):
# see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
maddr = ('224.0.0.198', 48542)

@timeout(2)
async def test_multisyncdgram(self):
# Test the implementation of the multicast version of
# SyncDatagram

_self = self
from ctypes import c_uint8

# seed the RNG
prngseed = b'abc123'
lora_comms.strobe_seed_prng((c_uint8 *
len(prngseed))(*prngseed), len(prngseed))

# Create the state for testing
commstate = lora_comms.CommsState()

# These are the expected messages and their arguments
exptmsgs = [
(CMD_WAITFOR, [ 30 ]),
(CMD_PING, [ ]),
(CMD_TERMINATE, [ ]),
]
def procmsg(msg, outbuf):
msgbuf = msg._from()
cmd = msgbuf[0]
args = [ int.from_bytes(msgbuf[x:x + 4],
byteorder='little') for x in range(1, len(msgbuf),
4) ]

if exptmsgs[0] == (cmd, args):
exptmsgs.pop(0)
outbuf[0].pkt[0] = cmd
outbuf[0].pktlen = 1
else: #pragma: no cover
raise RuntimeError('cmd not found')

# wrap the callback function
cb = lora_comms.process_msgfunc_t(procmsg)

# Generate shared key
shared_key = os.urandom(32)

# Initialize everything
lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

# create the object we are testing
msd = MulticastSyncDatagram(self.maddr)

seq = AsyncSequence()

async def clienttask():
mr = await multicast.create_multicast_receiver(
self.maddr)
mt = await multicast.create_multicast_transmitter(
self.maddr)

try:
# make sure the above threads are running
await seq.simpsync(0)

while True:
pkt = await mr.recv()
msg = pkt[0]

out = lora_comms.comms_process_wrap(
commstate, msg)

if out:
await mt.send(out)
finally:
mr.close()
mt.close()

task = asyncio.create_task(clienttask())

# start it
await msd.start()

# pass it to a node
l = LORANode(msd, shared=shared_key)

await seq.simpsync(1)

# Send various messages
await l.start()

await l.waitfor(30)

await l.ping()

await l.terminate()

# shut things down
ln = None
msd.close()

task.cancel()

with self.assertRaises(asyncio.CancelledError):
await task

+ 2
- 2
multicast.py View File

@@ -4,7 +4,7 @@ import socket
import struct
import unittest

from lora import timeout
from util import *

# This function based upon code from:
# https://gist.github.com/petrdvor/e802bec72e78ace061ab9d4469418fae#file-async-multicast-receiver-server-py-L54-L72
@@ -77,7 +77,7 @@ class TestMulticast(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_multicast(self):
# see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
maddr = ('224.0.0.%d' % random.randint(151, 250), 3485)
maddr = ('224.0.0.199', 3485)

l1 = await create_multicast_receiver(maddr)
l2 = await create_multicast_receiver(maddr)


+ 28
- 0
util.py View File

@@ -0,0 +1,28 @@
import asyncio
import functools

__all__ = [ 'timeout', '_debprint' ]

def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
return await asyncio.wait_for(fun(*args, **kwargs),
timeout)

return wrapper

return timeout_wrapper

def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
st = traceback.extract_stack(limit=2)[0]

sep = ''
if args:
sep = ':'

print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
*args)
sys.stdout.flush()


Loading…
Cancel
Save