# Copyright 2021 John-Mark Gurney. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF # SUCH DAMAGE. 
#

import os
import unittest

from ctypes import Structure, POINTER, CFUNCTYPE, pointer, sizeof
from ctypes import c_uint8, c_uint16, c_ssize_t, c_size_t, c_uint64, c_int
from ctypes import CDLL


class StructureRepr(object):
	'''Mixin that gives a ctypes Structure a repr listing each of its
	_fields_ as name=value.'''

	def __repr__(self): #pragma: no cover
		# _fields_ entries may be (name, type) or (name, type, bits), so
		# take the name by index instead of unpacking a 2-tuple.
		names = [field[0] for field in self._fields_]
		return '%s(%s)' % (self.__class__.__name__,
		    ', '.join('%s=%s' % (k, getattr(self, k)) for k in names))


class PktBuf(Structure):
	'''ctypes structure handed to the C library: a pointer to packet
	bytes (pkt) and their length (pktlen).'''

	_fields_ = [
		('pkt', POINTER(c_uint8)),
		('pktlen', c_uint16),
	]

	def _from(self):
		'''Return the pktlen bytes that pkt points at, as bytes.'''

		return bytes(self.pkt[:self.pktlen])

	def __repr__(self): #pragma: no cover
		return 'PktBuf(pkt=%s, pktlen=%s)' % (repr(self._from()), self.pktlen)


# Largest value PktBuf.pktlen (a c_uint16) can hold.
_PKTBUF_MAXLEN = 0xffff


def make_pktbuf(s):
	'''Wrap the byte sequence s in a PktBuf.

	If s is a non-empty bytearray, pkt points directly into its storage,
	so anything written through the PktBuf shows up in s.  Any other
	input (including an empty bytearray, which ctypes' from_buffer
	cannot point into) is copied into a new ctypes array.

	The backing objects are stored on the returned PktBuf so they stay
	alive as long as it does.

	Raises ValueError if s is longer than pktlen can represent.'''

	if len(s) > _PKTBUF_MAXLEN:
		# ctypes does no overflow checking; assigning a larger length
		# to the c_uint16 pktlen field would silently wrap.
		raise ValueError('packet too long: %d bytes (max %d)' %
		    (len(s), _PKTBUF_MAXLEN))

	pb = PktBuf()
	if isinstance(s, bytearray) and s:
		obj = s
		pb.pkt = pointer(c_uint8.from_buffer(s))
	else:
		obj = (c_uint8 * len(s))(*s)
		pb.pkt = obj

	pb.pktlen = len(s)
	pb._make_pktbuf_ref = (obj, s)

	return pb


# Callback type matching the C signature void (*)(PktBuf, PktBuf *).
process_msgfunc_t = CFUNCTYPE(None, PktBuf, POINTER(PktBuf))

# Load the C library under test; fall back to None so this module can
# still be imported when the library is not available.
try:
	_lib = CDLL('libsyote_test.dylib')
except OSError:
	_lib = None

if _lib is not None:
	_lib._strobe_state_size.restype = c_size_t
	_lib._strobe_state_size.argtypes = ()
	# Number of 8-byte words needed to hold the strobe state (rounded up).
	_strobe_state_u64_cnt = (_lib._strobe_state_size() + 7) // 8
else:
	_strobe_state_u64_cnt = 1


class CommsSession(Structure,StructureRepr):
	_fields_ = [
		('cs_crypto', c_uint64 * _strobe_state_u64_cnt),
		('cs_state', c_int),
	]


EC_PUBLIC_BYTES = 32
EC_PRIVATE_BYTES = 32


class CommsState(Structure,StructureRepr):
	_fields_ = [
		# The alignment of these may be off; the sizeof check against
		# _comms_state_size() below catches a total-size mismatch.
		('cs_active', CommsSession),
		('cs_pending', CommsSession),

		('cs_respkey', c_uint8 * EC_PRIVATE_BYTES),
		('cs_resppubkey', c_uint8 * EC_PUBLIC_BYTES),
		('cs_initpubkey', c_uint8 * EC_PUBLIC_BYTES),

		('cs_start', CommsSession),

		('cs_procmsg', process_msgfunc_t),

		('cs_prevmsg', PktBuf),
		('cs_prevmsgresp', PktBuf),
		('cs_prevmsgbuf', c_uint8 * 64),
		('cs_prevmsgrespbuf', c_uint8 * 64),
	]


if _lib is not None:
	_lib._comms_state_size.restype = c_size_t
	_lib._comms_state_size.argtypes = ()

	if _lib._comms_state_size() != sizeof(CommsState): # pragma: no cover
		raise RuntimeError('CommsState structure size mismatch!')

	X25519_BASE_POINT = (c_uint8 * EC_PUBLIC_BYTES).in_dll(_lib,
	    'X25519_BASE_POINT')

	for func, ret, args in [
		('comms_init', c_int, (POINTER(CommsState), process_msgfunc_t,
		    POINTER(PktBuf), POINTER(PktBuf), POINTER(PktBuf))),
		('comms_process', None, (POINTER(CommsState), PktBuf,
		    POINTER(PktBuf))),
		('strobe_seed_prng', None, (POINTER(c_uint8), c_ssize_t)),
		('x25519', c_int, (c_uint8 * EC_PUBLIC_BYTES,
		    c_uint8 * EC_PRIVATE_BYTES, c_uint8 * EC_PUBLIC_BYTES,
		    c_int)),
	]:
		f = getattr(_lib, func)
		f.restype = ret
		f.argtypes = args
		# Export each configured C function as a module-level name.
		# (At module scope locals() is globals(), but say what is meant.)
		globals()[func] = f


def _require_len(name, buf, size):
	'''Raise ValueError unless buf is exactly size bytes long.

	ctypes' from_buffer_copy only rejects buffers that are too short;
	longer buffers are silently truncated, which would hide caller
	bugs such as an over-long key.'''

	if len(buf) != size:
		raise ValueError('%s must be %d bytes, got %d' %
		    (name, size, len(buf)))


def x25519_wrap(out, scalar, base, clamp):
	'''Call the C x25519 function on scalar and base and return the
	EC_PUBLIC_BYTES result as bytes.

	out only seeds the output buffer (it is copied, never modified).
	scalar must be EC_PRIVATE_BYTES long and base EC_PUBLIC_BYTES long.
	clamp is passed through unchanged to the C function.

	Raises ValueError on a wrong-length scalar or base, and
	RuntimeError if x25519 returns non-zero.'''

	_require_len('scalar', scalar, EC_PRIVATE_BYTES)
	_require_len('base', base, EC_PUBLIC_BYTES)

	outptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer_copy(out)
	scalarptr = (c_uint8 * EC_PRIVATE_BYTES).from_buffer_copy(scalar)
	# base is a public point, so size it as one.
	baseptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer_copy(base)

	r = x25519(outptr, scalarptr, baseptr, clamp)
	if r != 0:
		raise RuntimeError('x25519 failed')

	return bytes(outptr)


def x25519_genkey():
	'''Return EC_PRIVATE_BYTES bytes from os.urandom for use as a
	private key.'''

	return os.urandom(EC_PRIVATE_BYTES)


def x25519_base(scalar, clamp):
	'''Return the public key for the private key scalar, i.e. the C
	x25519 function applied to scalar and X25519_BASE_POINT.

	Raises ValueError if scalar is not EC_PRIVATE_BYTES long, and
	RuntimeError if x25519 returns non-zero.'''

	_require_len('scalar', scalar, EC_PRIVATE_BYTES)

	out = bytearray(EC_PUBLIC_BYTES)
	# Writes by the C function land directly in out.
	outptr = (c_uint8 * EC_PUBLIC_BYTES).from_buffer(out)
	scalarptr = (c_uint8 * EC_PRIVATE_BYTES).from_buffer_copy(scalar)

	r = x25519(outptr, scalarptr, X25519_BASE_POINT, clamp)
	if r != 0:
		raise RuntimeError('x25519 failed')

	return bytes(out)


class X25519:
	'''Class to wrap the x25519 functions into something a bit more
	usable.

	Use either the gen method to generate a random key, or the
	frombytes method to load an existing EC_PRIVATE_BYTES key.

	a = X25519.gen()
	b = X25519.gen()
	a.dh(b.getpub()) == b.dh(a.getpub())

	That is, each party generates a key, sends its public part to the
	other party, and then passes the public part it received to the dh
	method.  The resulting value is shared between the two parties.
	'''

	def __init__(self, key):
		# x25519_base raises ValueError if key has the wrong length.
		self.privkey = key
		self.pubkey = x25519_base(key, 1)

	def dh(self, pub):
		'''Perform a DH operation using the public part pub.'''

		# self.pubkey only seeds the output buffer; it is not modified.
		return x25519_wrap(self.pubkey, self.privkey, pub, 1)

	def getpub(self):
		'''Get the public part of the key.  This is to be sent to
		the other party for key exchange.'''

		return self.pubkey

	def getpriv(self):
		'''Get the private part of the key.'''

		return self.privkey

	@classmethod
	def gen(cls):
		'''Generate a random X25519 key.'''

		return cls(x25519_genkey())

	@classmethod
	def frombytes(cls, key):
		'''Create an X25519 key from EC_PRIVATE_BYTES (32) bytes.'''

		return cls(key)


def comms_process_wrap(state, input):
	'''A wrapper around comms_process that converts the argument into
	a PktBuf, and then returns the response message as a bytes string.

	The response buffer is 64 bytes; the returned bytes are its first
	pktlen bytes after comms_process runs.
	'''

	inpkt = make_pktbuf(input)

	outbytes = bytearray(64)
	outbuf = make_pktbuf(outbytes)

	comms_process(state, inpkt, outbuf)

	return outbuf._from()


@unittest.skipIf(_lib is None, 'libsyote_test.dylib not available')
class TestX25519(unittest.TestCase):
	PUBLIC_BYTES = EC_PUBLIC_BYTES
	PRIVATE_BYTES = EC_PRIVATE_BYTES

	def test_class(self):
		key = X25519.gen()

		pubkey = key.getpub()
		privkey = key.getpriv()

		apubkey = x25519_base(privkey, 1)

		self.assertEqual(apubkey, pubkey)

		self.assertEqual(X25519.frombytes(privkey).getpub(), pubkey)

		with self.assertRaises(ValueError):
			X25519(b'0'*31)

	def test_rfc7748_6_1(self):
		# KAT from https://datatracker.ietf.org/doc/html/rfc7748#section-6.1
		apriv = bytes.fromhex('77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a')
		akey = X25519(apriv)

		self.assertEqual(akey.getpub(), bytes.fromhex('8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a'))

		bpriv = bytes.fromhex('5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb')
		bkey = X25519(bpriv)

		self.assertEqual(bkey.getpub(), bytes.fromhex('de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f'))

		ss = bytes.fromhex('4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742')

		self.assertEqual(akey.dh(bkey.getpub()), ss)
		self.assertEqual(bkey.dh(akey.getpub()), ss)

	def test_basic_ops(self):
		aprivkey = x25519_genkey()
		apubkey = x25519_base(aprivkey, 1)

		bprivkey = x25519_genkey()
		bpubkey = x25519_base(bprivkey, 1)

		self.assertNotEqual(aprivkey, bprivkey)
		self.assertNotEqual(apubkey, bpubkey)

		ra = x25519_wrap(apubkey, aprivkey, bpubkey, 1)

		rb = x25519_wrap(bpubkey, bprivkey, apubkey, 1)

		self.assertEqual(ra, rb)