| @@ -1,6 +1,11 @@ | |||
| from noise.connection import NoiseConnection, Keypair | |||
| from cryptography.hazmat.primitives.kdf.hkdf import HKDF | |||
| from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes | |||
| from cryptography.hazmat.primitives import hashes | |||
| from twistednoise import genkeypair | |||
| from cryptography.hazmat.backends import default_backend | |||
| import asyncio | |||
| import os.path | |||
| import shutil | |||
| import socket | |||
| @@ -8,23 +13,27 @@ import tempfile | |||
| import threading | |||
| import unittest | |||
| _backend = default_backend() | |||
| def _makeunix(path): | |||
| '''Make a properly formed unix path socket string.''' | |||
| return 'unix:%s' % path | |||
| def _acceptfun(s, fun): | |||
| while True: | |||
| sock = s.accept() | |||
| def _parsesockstr(sockstr): | |||
| proto, rem = sockstr.split(':', 1) | |||
| return proto, rem | |||
| async def connectsockstr(sockstr): | |||
| proto, rem = _parsesockstr(sockstr) | |||
| fun(*sock) | |||
| reader, writer = await asyncio.open_unix_connection(rem) | |||
| def listensocket(sockstr, fun): | |||
| '''Listen for connections on sockstr. When ever a connection | |||
| is accepted, the parameter fun is called with the socket and | |||
| the from address. The return will be a Thread object. Note | |||
| that fun MUST NOT block, as if it does, it will stop accepting | |||
| other connections. | |||
| return reader, writer | |||
| async def listensockstr(sockstr, cb): | |||
| '''Wrapper for asyncio.start_x_server. | |||
| The format of sockstr is: 'proto:param=value[,param2=value2]'. | |||
| If the proto has a default parameter, the value can be used | |||
| @@ -43,29 +52,123 @@ def listensocket(sockstr, fun): | |||
| slash if it is used as a default parameter. | |||
| ''' | |||
| proto, rem = sockstr.split(':', 1) | |||
| proto, rem = _parsesockstr(sockstr) | |||
| s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) | |||
| s.bind(rem) | |||
| s.listen(-1) | |||
| server = await asyncio.start_unix_server(cb, path=rem) | |||
| thr = threading.Thread(target=_acceptfun, name='accept thread: %s' % repr(sockstr), args=(s, fun)) | |||
| thr.setDaemon(True) | |||
| return server | |||
| thr.start() | |||
| # !!python makemessagelengths.py | |||
| _handshakelens = \ | |||
| [72, 72, 88] | |||
| return thr | |||
| def _genciphfun(hash, ad): | |||
| hkdf = HKDF(algorithm=hashes.SHA256(), length=32, | |||
| salt=b'asdoifjsldkjdsf', info=ad, backend=_backend) | |||
| class NoiseForwarder(object): | |||
| def __init__(self, mode, sock, ): | |||
| nf = NoiseForwarder('resp', self.server_key_pair[1], ssock, pttarg) | |||
| pass | |||
| key = hkdf.derive(hash) | |||
| cipher = Cipher(algorithms.AES(key), modes.ECB(), | |||
| backend=_backend) | |||
| enctor = cipher.encryptor() | |||
| def encfun(data): | |||
| # Returns the two bytes for length | |||
| val = len(data) | |||
| encbytes = enctor.update(data[:16]) | |||
| mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff | |||
| return (val ^ mask).to_bytes(length=2, byteorder='big') | |||
| def decfun(data): | |||
| # takes off the data and returns the total | |||
| # length | |||
| val = int.from_bytes(data[:2], byteorder='big') | |||
| encbytes = enctor.update(data[2:2 + 16]) | |||
| mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff | |||
| return val ^ mask | |||
| return encfun, decfun | |||
| async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): | |||
| rdr, wrr = rdrwrr | |||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) | |||
| proto.set_as_responder() | |||
| proto.start_handshake() | |||
| proto.read_message(await rdr.readexactly(_handshakelens[0])) | |||
| wrr.write(proto.write_message()) | |||
| proto.read_message(await rdr.readexactly(_handshakelens[2])) | |||
| if not proto.handshake_finished: # pragma: no cover | |||
| raise RuntimeError('failed to finish handshake') | |||
| # generate the keys for lengths | |||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') | |||
| enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') | |||
| reader, writer = await connectsockstr(ptsockstr) | |||
| async def decses(): | |||
| while True: | |||
| msg = await rdr.readexactly(2 + 16) | |||
| tlen = declenfun(msg) | |||
| rmsg = await rdr.readexactly(tlen - 16) | |||
| tmsg = msg[2:] + rmsg | |||
| writer.write(proto.decrypt(tmsg)) | |||
| await writer.drain() | |||
| async def encses(): | |||
| while True: | |||
| ptmsg = await reader.read(65535 - 16) # largest message | |||
| encmsg = proto.encrypt(ptmsg) | |||
| wrr.write(enclenfun(encmsg)) | |||
| wrr.write(encmsg) | |||
| await wrr.drain() | |||
| r = await asyncio.gather(decses(), encses(), return_exceptions=True) | |||
| print(repr(r)) | |||
| return r | |||
| class TestListenSocket(unittest.TestCase): | |||
| def test_listensocket(self): | |||
| def test_listensockstr(self): | |||
| # XXX write test | |||
| pass | |||
| # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code | |||
| def async_test(f): | |||
| def wrapper(*args, **kwargs): | |||
| coro = asyncio.coroutine(f) | |||
| future = coro(*args, **kwargs) | |||
| loop = asyncio.get_event_loop() | |||
| # timeout after 2 seconds | |||
| loop.run_until_complete(asyncio.wait_for(future, 2)) | |||
| return wrapper | |||
| class Tests_misc(unittest.TestCase): | |||
| def test_genciphfun(self): | |||
| enc, dec = _genciphfun(b'0' * 32, b'foobar') | |||
| msg = b'this is a bunch of data' | |||
| tb = enc(msg) | |||
| self.assertEqual(len(msg), dec(tb + msg)) | |||
| for i in [ 20, 1384, 64000, 23839, 65535 ]: | |||
| msg = os.urandom(i) | |||
| self.assertEqual(len(msg), dec(enc(msg) + msg)) | |||
| class Tests(unittest.TestCase): | |||
| def setUp(self): | |||
| # setup temporary directory | |||
| @@ -82,7 +185,8 @@ class Tests(unittest.TestCase): | |||
| shutil.rmtree(self.basetempdir) | |||
| self.tempdir = None | |||
| def test_server(self): | |||
| @async_test | |||
| async def test_server(self): | |||
| # Path that the server will sit on | |||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||
| servarg = _makeunix(servsockpath) | |||
| @@ -93,14 +197,17 @@ class Tests(unittest.TestCase): | |||
| # Setup pt target listener | |||
| pttarg = _makeunix(servsockpath) | |||
| ptsock = [] | |||
| def ptsockaccept(sock, frm, ptsock=ptsock): | |||
| ptsock.append(sock) | |||
| def ptsockaccept(reader, writer, ptsock=ptsock): | |||
| ptsock.append((reader, writer)) | |||
| # Bind to pt listener | |||
| lsock = listensocket(pttarg, ptsockaccept) | |||
| lsock = await listensockstr(pttarg, ptsockaccept) | |||
| # Setup server listener | |||
| ssock = listensocket(servarg, lambda x, y: NoiseForwarder('resp', self.server_key_pair[1], x, pttarg)) | |||
| ssock = await listensockstr(servarg, lambda rdr, wrr: NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg)) | |||
| # Connect to server | |||
| reader, writer = await connectsockstr(servarg) | |||
| # Create client | |||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| @@ -114,3 +221,70 @@ class Tests(unittest.TestCase): | |||
| # Send first message | |||
| message = proto.write_message() | |||
| self.assertEqual(len(message), _handshakelens[0]) | |||
| writer.write(message) | |||
| # Get response | |||
| respmsg = await reader.readexactly(_handshakelens[1]) | |||
| proto.read_message(respmsg) | |||
| # Send final reply | |||
| message = proto.write_message() | |||
| writer.write(message) | |||
| # Make sure handshake has completed | |||
| self.assertTrue(proto.handshake_finished) | |||
| # generate the keys for lengths | |||
| enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') | |||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') | |||
| # write a test message | |||
| ptmsg = b'this is a test message that should be a little in length' | |||
| encmsg = proto.encrypt(ptmsg) | |||
| writer.write(enclenfun(encmsg)) | |||
| writer.write(encmsg) | |||
| # XXX - how to sync? | |||
| await asyncio.sleep(.1) | |||
| # read the test message | |||
| rptmsg = await ptsock[0][0].readexactly(len(ptmsg)) | |||
| self.assertEqual(rptmsg, ptmsg) | |||
| # write a different message | |||
| ptmsg = os.urandom(2843) | |||
| encmsg = proto.encrypt(ptmsg) | |||
| writer.write(enclenfun(encmsg)) | |||
| writer.write(encmsg) | |||
| # XXX - how to sync? | |||
| await asyncio.sleep(.1) | |||
| # read the test message | |||
| rptmsg = await ptsock[0][0].readexactly(len(ptmsg)) | |||
| self.assertEqual(rptmsg, ptmsg) | |||
| # now try the other way | |||
| ptmsg = os.urandom(912) | |||
| ptsock[0][1].write(ptmsg) | |||
| # find out how much we need to read | |||
| encmsg = await reader.readexactly(2 + 16) | |||
| tlen = declenfun(encmsg) | |||
| # read the rest of the message | |||
| rencmsg = await reader.readexactly(tlen - 16) | |||
| tmsg = encmsg[2:] + rencmsg | |||
| rptmsg = proto.decrypt(tmsg) | |||
| self.assertEqual(rptmsg, ptmsg) | |||
| # shut everything down | |||
| ptsock[0][1].write_eof() | |||
| writer.write_eof() | |||
| # XXX - how to sync? | |||
| await asyncio.sleep(1) | |||