@@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair
from noise.connection import NoiseConnection, Keypair
#import tracemalloc; tracemalloc.start()
#import tracemalloc; tracemalloc.start(100 )
import argparse
import argparse
import asyncio
import asyncio
@@ -81,15 +81,59 @@ def _makeunix(path):
return 'unix:%s' % path
return 'unix:%s' % path
_allowedparameters = {
'unix': {
'path': str,
},
'tcp': {
'host': str,
'port': int,
},
}
def _parsesockstr(sockstr):
def _parsesockstr(sockstr):
'''Parse a socket string to its parts. If there are no
kwargs (no = after the colon), a dictionary w/ a single
key of default will pass the string after the colon.
default is a reserved keyword and MUST NOT be used.'''
proto, rem = sockstr.split(':', 1)
proto, rem = sockstr.split(':', 1)
return proto, rem
if '=' not in rem:
if proto == 'unix' and rem[0] != '/':
raise ValueError('bare path MUST start w/ a slash (/).')
if proto == 'unix':
args = { 'path': rem }
else:
args = dict(i.split('=', 1) for i in rem.split(','))
try:
allowed = _allowedparameters[proto]
except KeyError:
raise ValueError('unsupported proto: %s' % repr(proto))
extrakeys = args.keys() - allowed.keys()
if extrakeys:
raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))
for i in args:
args[i] = allowed[i](args[i])
return proto, args
async def connectsockstr(sockstr):
async def connectsockstr(sockstr):
proto, rem = _parsesockstr(sockstr)
'''Wrapper for asyncio.open_*_connection.'''
reader, writer = await asyncio.open_unix_connection(rem)
proto, args = _parsesockstr(sockstr)
if proto == 'unix':
fun = asyncio.open_unix_connection
elif proto == 'tcp':
fun = asyncio.open_connection
reader, writer = await fun(**args)
return reader, writer
return reader, writer
@@ -101,9 +145,15 @@ async def listensockstr(sockstr, cb):
directly, like: 'proto:value'. This is only allowed when the
directly, like: 'proto:value'. This is only allowed when the
value can unambiguously be determined not to be a param.
value can unambiguously be determined not to be a param.
The cb parameter is passed to asyncio's start_server or related
calls. Per those docs, the cb parameter is calls or scheduled
as a task when a client establishes a connection. It is called
with two arguments, the reader and writer streams. For more
information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
The characters that define 'param' must be all lower case ascii
The characters that define 'param' must be all lower case ascii
characters and may contain an underscore. The first character
characters and may contain an underscore. The first character
must not be and underscore.
must not be an underscore.
Supported protocols:
Supported protocols:
unix:
unix:
@@ -113,11 +163,14 @@ async def listensockstr(sockstr, cb):
slash if it is used as a default parameter.
slash if it is used as a default parameter.
'''
'''
proto, rem = _parsesockstr(sockstr)
proto, args = _parsesockstr(sockstr)
server = await asyncio.start_unix_server(cb, path=rem)
if proto == 'unix':
fun = asyncio.start_unix_server
elif proto == 'tcp':
fun = asyncio.start_server
return server
return await fun(cb, **args)
# !!python makemessagelengths.py
# !!python makemessagelengths.py
_handshakelens = \
_handshakelens = \
@@ -212,8 +265,12 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
if not proto.handshake_finished: # pragma: no cover
if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake')
raise RuntimeError('failed to finish handshake')
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
try:
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
except:
wrr.close()
raise
# generate the keys for lengths
# generate the keys for lengths
# XXX - get_handshake_hash is probably not the best option, but
# XXX - get_handshake_hash is probably not the best option, but
@@ -273,7 +330,15 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
finally:
finally:
wrr.write_eof()
wrr.write_eof()
return await asyncio.gather(decses(), encses())
res = await asyncio.gather(decses(), encses())
await wrr.drain() # not sure if needed
wrr.close()
await writer.drain() # not sure if needed
writer.close()
return res
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
# Slightly modified to timeout and to print trace back when canceled.
@@ -296,9 +361,107 @@ def async_test(f):
return wrapper
return wrapper
class Tests_misc(unittest.TestCase):
class Tests_misc(unittest.TestCase):
def test_listensockstr(self):
# XXX write test
pass
def setUp(self):
# setup temporary directory
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d
self.tempdir = os.path.join(d, 'subdir')
os.mkdir(self.tempdir)
os.chdir(self.tempdir)
def tearDown(self):
#print('td:', time.time())
shutil.rmtree(self.basetempdir)
self.tempdir = None
def test_parsesockstr_bad(self):
badstrs = [
'unix:ff',
'randomnocolon',
'unix:somethingelse=bogus',
'tcp:port=bogus',
]
for i in badstrs:
with self.assertRaises(ValueError,
msg='Should have failed processing: %s' % repr(i)):
_parsesockstr(i)
def test_parsesockstr(self):
results = {
# Not all of these are valid when passed to a *sockstr
# function
'unix:/apath': ('unix', { 'path': '/apath' }),
'unix:path=apath': ('unix', { 'path': 'apath' }),
'tcp:host=apath': ('tcp', { 'host': 'apath' }),
'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
'port': 5 }),
}
for s, r in results.items():
self.assertEqual(_parsesockstr(s), r)
@async_test
async def test_listensockstr_bad(self):
with self.assertRaises(ValueError):
ls = await listensockstr('bogus:some=arg', None)
with self.assertRaises(ValueError):
ls = await connectsockstr('bogus:some=arg')
@async_test
async def test_listenconnectsockstr(self):
msgsent = b'this is a test message'
msgrcv = b'testing message for receive'
# That when a connection is received and receives and sends
async def servconfhandle(rdr, wrr):
msg = await rdr.readexactly(len(msgsent))
self.assertEqual(msg, msgsent)
#print(repr(wrr.get_extra_info('sockname')))
wrr.write(msgrcv)
await wrr.drain()
wrr.close()
return True
# Test listensockstr
for sstr, confun in [
('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
]:
# that listensockstr will bind to the correct path, can call cb
ls = await listensockstr(sstr, servconfhandle)
# that we open a connection to the path
rdr, wrr = await confun()
# and send a message
wrr.write(msgsent)
# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)
wrr.close()
# Now test that connectsockstr works similarly.
rdr, wrr = await connectsockstr(sstr)
# and send a message
wrr.write(msgsent)
# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)
wrr.close()
ls.close()
await ls.wait_closed()
def test_genciphfun(self):
def test_genciphfun(self):
enc, dec = _genciphfun(b'0' * 32, b'foobar')
enc, dec = _genciphfun(b'0' * 32, b'foobar')
@@ -531,14 +694,10 @@ class TestMain(unittest.TestCase):
servtargstr = _makeunix(servtargpath)
servtargstr = _makeunix(servtargpath)
# Setup server target listener
# Setup server target listener
ptsock = []
ptsockevent = asyncio.Event()
ptsockevent = asyncio.Event()
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))
ptsockevent.set()
# Bind to pt listener
# Bind to pt listener
lsock = await listensockstr(servtargstr, ptsockaccept )
lsock = await listensockstr(servtargstr, None)
# Startup the server
# Startup the server
server = await self.run_with_args('server',
server = await self.run_with_args('server',
@@ -566,6 +725,8 @@ class TestMain(unittest.TestCase):
# make sure that we don't get the conenction
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)
await asyncio.wait_for(ptsockevent.wait(), .5)
writer.close()
# Make sure that when the server is terminated
# Make sure that when the server is terminated
server.terminate()
server.terminate()
@@ -577,6 +738,9 @@ class TestMain(unittest.TestCase):
# even the example echo server has this same leak
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)
#self.assertNotIn(b'Task exception was never retrieved', stderr)
lsock.close()
await lsock.wait_closed()
@async_test
@async_test
async def test_end2end(self):
async def test_end2end(self):
# Generate necessar keys
# Generate necessar keys
@@ -654,6 +818,12 @@ class TestMain(unittest.TestCase):
endwrr.write(msg)
endwrr.write(msg)
self.assertEqual(msg, await reader.readexactly(len(msg)))
self.assertEqual(msg, await reader.readexactly(len(msg)))
writer.close()
endwrr.close()
lsock.close()
await lsock.wait_closed()
@async_test
@async_test
async def test_genkey(self):
async def test_genkey(self):
# that it can generate a key
# that it can generate a key
@@ -757,6 +927,8 @@ class TestNoiseFowarder(unittest.TestCase):
with self.assertRaises(ValueError):
with self.assertRaises(ValueError):
await servnf
await servnf
writer.close()
@async_test
@async_test
async def test_server(self):
async def test_server(self):
# Test is plumbed:
# Test is plumbed:
@@ -893,6 +1065,14 @@ class TestNoiseFowarder(unittest.TestCase):
self.assertEqual(nfs[0], [ 'dec', 'enc' ])
self.assertEqual(nfs[0], [ 'dec', 'enc' ])
writer.close()
ptwriter.close()
lsock.close()
ssock.close()
await lsock.wait_closed()
await ssock.wait_closed()
@async_test
@async_test
async def test_serverclient(self):
async def test_serverclient(self):
# plumbing:
# plumbing:
@@ -910,8 +1090,7 @@ class TestNoiseFowarder(unittest.TestCase):
ptsbreader, ptsbwriter = await ptssockbpair
ptsbreader, ptsbwriter = await ptssockbpair
async def validateclientkey(pubkey):
async def validateclientkey(pubkey):
if pubkey != self.client_key_pair[0]:
raise ValueError('invalid key')
self.assertEqual(pubkey, self.client_key_pair[0])
return await ptssockapair
return await ptssockapair
@@ -963,3 +1142,9 @@ class TestNoiseFowarder(unittest.TestCase):
self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)
await ptsbwriter.drain()
await ptcawriter.drain()
ptsbwriter.close()
ptcawriter.close()