| @@ -317,6 +317,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): | |||
| rmsg = await rdr.readexactly(tlen - 16) | |||
| tmsg = msg[2:] + rmsg | |||
| rpv = proto.decrypt(tmsg) | |||
| rempv = int.from_bytes(rpv, byteorder='big') | |||
| if rempv != protocol_version: | |||
| raise RuntimeError('unsupported protovol version received: %d' % | |||
| rempv) | |||
| async def decses(): | |||
| try: | |||
| @@ -1125,6 +1130,116 @@ class TestNoiseFowarder(unittest.TestCase): | |||
| await lsock.wait_closed() | |||
| await ssock.wait_closed() | |||
| @async_test | |||
| async def test_protocolversionmismatch(self): | |||
| # make sure that if we send a future version, that we | |||
| # still get a protocol version, and that the connection | |||
| # is closed w/o establishing a connection to the remote | |||
| # side | |||
| # Test is plumbed: | |||
| # (reader, writer) -> servsock -> | |||
| # (rdr, wrr) NoiseForward (reader, writer) -> | |||
| # servptsock -> (ptsock[0], ptsock[1]) | |||
| # Path that the server will sit on | |||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||
| servarg = _makeunix(servsockpath) | |||
| # Path that the server will send pt data to | |||
| servptpath = os.path.join(self.tempdir, 'servptsock') | |||
| # Setup pt target listener | |||
| pttarg = _makeunix(servptpath) | |||
| ptsock = [] | |||
| ptsockevent = asyncio.Event() | |||
| def ptsockaccept(reader, writer, ptsock=ptsock): | |||
| ptsock.append((reader, writer)) | |||
| ptsockevent.set() | |||
| # Bind to pt listener | |||
| lsock = await listensockstr(pttarg, ptsockaccept) | |||
| nfs = [] | |||
| event = asyncio.Event() | |||
| async def runnf(rdr, wrr): | |||
| ptpairfun = asyncio.create_task(connectsockstr(pttarg)) | |||
| try: | |||
| a = await NoiseForwarder('resp', | |||
| _makefut((rdr, wrr)), lambda x: ptpairfun, | |||
| priv_key=self.server_key_pair[1]) | |||
| except RuntimeError as e: | |||
| nfs.append(e) | |||
| event.set() | |||
| return | |||
| nfs.append(a) | |||
| event.set() | |||
| # Setup server listener | |||
| ssock = await listensockstr(servarg, runnf) | |||
| # Connect to server | |||
| reader, writer = await connectsockstr(servarg) | |||
| # Create client | |||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| proto.set_as_initiator() | |||
| # Setup required keys | |||
| proto.set_keypair_from_private_bytes(Keypair.STATIC, | |||
| self.client_key_pair[1]) | |||
| proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, | |||
| self.server_key_pair[0]) | |||
| proto.start_handshake() | |||
| # Send first message | |||
| message = proto.write_message() | |||
| self.assertEqual(len(message), _handshakelens[0]) | |||
| writer.write(message) | |||
| # Get response | |||
| respmsg = await reader.readexactly(_handshakelens[1]) | |||
| proto.read_message(respmsg) | |||
| # Send final reply | |||
| message = proto.write_message() | |||
| writer.write(message) | |||
| # Make sure handshake has completed | |||
| self.assertTrue(proto.handshake_finished) | |||
| # generate the keys for lengths | |||
| enclenfun, _ = _genciphfun(proto.get_handshake_hash(), | |||
| b'toresp') | |||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), | |||
| b'toinit') | |||
| pversion = 1 | |||
| # Send the protocol version string first | |||
| encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big')) | |||
| writer.write(enclenfun(encmsg)) | |||
| writer.write(encmsg) | |||
| # Read the peer's protocol version | |||
| # find out how much we need to read | |||
| encmsg = await reader.readexactly(2 + 16) | |||
| tlen = declenfun(encmsg) | |||
| # read the rest of the message | |||
| rencmsg = await reader.readexactly(tlen - 16) | |||
| tmsg = encmsg[2:] + rencmsg | |||
| rptmsg = proto.decrypt(tmsg) | |||
| self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0) | |||
| await event.wait() | |||
| self.assertIsInstance(nfs[0], RuntimeError) | |||
| @async_test | |||
| async def test_serverclient(self): | |||
| # plumbing: | |||