diff --git a/comms.c b/comms.c index c38bc29..acb6073 100644 --- a/comms.c +++ b/comms.c @@ -30,6 +30,8 @@ static const size_t MAC_LEN = 8; static const size_t CHALLENGE_LEN = 16; static const uint8_t domain[] = "com.funkthat.lora.irrigation.shared.v0.0.1"; +static const uint8_t reqreset[] = "reqreset"; +static const uint8_t confirm[] = "confirm"; static int comms_pktbuf_equal(struct pktbuf a, struct pktbuf b); @@ -105,6 +107,7 @@ _comms_process_session(struct comms_state *cs, struct comms_session *sess, struc /* MAC check failed */ if (cnt == -1) { +badmsg: /* restore the previous state */ sess->cs_crypto = tmp; pbout->pktlen = 0; @@ -122,7 +125,9 @@ _comms_process_session(struct comms_state *cs, struct comms_session *sess, struc ret = 0; switch (sess->cs_state) { case COMMS_WAIT_REQUEST: - /* XXX - reqreset check */ + if (msglen != 24 || memcmp(reqreset, &buf[16], + sizeof reqreset - 1) != 0) + goto badmsg; bare_strobe_randomize(buf, CHALLENGE_LEN); ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, buf, @@ -135,7 +140,10 @@ _comms_process_session(struct comms_state *cs, struct comms_session *sess, struc break; case COMMS_WAIT_CONFIRM: - /* XXX - confirm check */ + if (msglen != 7 || memcmp(confirm, buf, + sizeof confirm - 1) != 0) + goto badmsg; + ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, CONFIRMED_STR, CONFIRMED_STR_LEN); ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); diff --git a/lora.py b/lora.py index 71b689d..6669519 100644 --- a/lora.py +++ b/lora.py @@ -562,10 +562,97 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): self.assertTrue(tsd.recvq.empty()) #_debprint('done') + @timeout(2) + async def test_ccode_badmsgs(self): + # Test to make sure that various bad messages in the + # handshake process are rejected even if the attacker + # has the correct key. This just keeps the protocol + # tight allowing for variations in the future. + + # seed the RNG + prngseed = b'abc123' + from ctypes import c_uint8 + lora_comms.strobe_seed_prng((c_uint8 * + len(prngseed))(*prngseed), len(prngseed)) + + # Create the state for testing + commstate = lora_comms.CommsState() + + # dummy callback + def procmsg(msg, outbuf): + pass + + cb = lora_comms.process_msgfunc_t(procmsg) + + # Generate shared key + shared_key = os.urandom(32) + + # Initialize everything + lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key)) + + # Create test fixture, only use it to init crypto state + tsd = SyncDatagram() + l = LORANode(tsd, shared=shared_key) + + # copy the crypto state + cstate = l.st.copy() + + # compose an incorrect init message + msg = os.urandom(16) + b'othre' + msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN) + + out = lora_comms.comms_process_wrap(commstate, msg) + + self.assertFalse(out) + + # copy the crypto state + cstate = l.st.copy() + + # compose an incorrect init message + msg = os.urandom(16) + b' eqreset' + msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN) + + out = lora_comms.comms_process_wrap(commstate, msg) + + self.assertFalse(out) + + # compose the correct init message + msg = os.urandom(16) + b'reqreset' + msg = l.st.send_enc(msg) + l.st.send_mac(l.MAC_LEN) + + out = lora_comms.comms_process_wrap(commstate, msg) + + l.st.recv_enc(out[:-l.MAC_LEN]) + l.st.recv_mac(out[-l.MAC_LEN:]) + + l.st.ratchet() + + # copy the crypto state + cstate = l.st.copy() + + # compose an incorrect confirmed message + msg = b'onfirm' + msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN) + + out = lora_comms.comms_process_wrap(commstate, msg) + + self.assertFalse(out) + + # copy the crypto state + cstate = l.st.copy() + + # compose an incorrect confirmed message + msg = b' onfirm' + msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN) + + out = lora_comms.comms_process_wrap(commstate, msg) + + self.assertFalse(out) + @timeout(2) async def test_ccode(self): _self = self - from ctypes import pointer, sizeof, c_uint8 + from ctypes import c_uint8 # seed the RNG prngseed = b'abc123' @@ -675,7 +762,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): ''' _self = self - from ctypes import pointer, sizeof, c_uint8 + from ctypes import c_uint8 seq = AsyncSequence() diff --git a/lora_comms.py b/lora_comms.py index ca1d3f7..25ecc26 100644 --- a/lora_comms.py +++ b/lora_comms.py @@ -105,3 +105,17 @@ for func, ret, args in [ f.restype = ret f.argtypes = args locals()[func] = f + +def comms_process_wrap(state, input): + '''A wrapper around comms_process that converts the argument + into the buffer, and the returns the message as a bytes string. + ''' + + inpkt = make_pktbuf(input) + + outbytes = bytearray(64) + outbuf = make_pktbuf(outbytes) + + comms_process(state, inpkt, outbuf) + + return outbytes[:outbuf.pktlen]