| @@ -60,6 +60,11 @@ __license__ = '2-clause BSD license' | |||
| # the connection aborts because of decryption failure. | |||
| # | |||
| def _makeunix(path): | |||
| '''Make a properly formed unix path socket string.''' | |||
| return 'unix:%s' % path | |||
| def genkeypair(): | |||
| '''Generates a keypair, and returns a tuple of (public, private). | |||
| They are encoded as raw bytes, and sutible for use w/ Noise.''' | |||
| @@ -146,18 +151,18 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||
| # start the connection to the endpoint | |||
| ep = endpoints.clientFromString(reactor, self.factory.endpoint) | |||
| epdef = ep.connect(ClientProxyFactory(self)) | |||
| epdef = ep.connect(ServerPTProxyFactory(self)) | |||
| epdef.addCallback(self.plaintextConnected) | |||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | |||
| mode = 'init' | |||
| class ClientProxyProtocol(twisted.internet.protocol.Protocol): | |||
| class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): | |||
| def dataReceived(self, data): | |||
| self.factory.noiseproto.encData(data) | |||
| class ClientProxyFactory(Factory): | |||
| protocol = ClientProxyProtocol | |||
| class ServerPTProxyFactory(Factory): | |||
| protocol = ServerPTProxyProtocol | |||
| def __init__(self, noiseproto): | |||
| self.noiseproto = noiseproto | |||
| @@ -172,12 +177,17 @@ class TwistedNoiseServerFactory(Factory): | |||
| class TNServerTest(unittest.TestCase): | |||
| @defer.inlineCallbacks | |||
| def setUp(self): | |||
| # setup temporary directory | |||
| d = os.path.realpath(tempfile.mkdtemp()) | |||
| self.basetempdir = d | |||
| self.tempdir = os.path.join(d, 'subdir') | |||
| os.mkdir(self.tempdir) | |||
| # Generate key pairs | |||
| self.server_key_pair = genkeypair() | |||
| self.client_key_pair = genkeypair() | |||
| # Server's PT client will be here | |||
| self.protos = [] | |||
| self.connectionmade = defer.Deferred() | |||
| @@ -195,23 +205,21 @@ class TNServerTest(unittest.TestCase): | |||
| self.__tc.protos.append(r) | |||
| return r | |||
| sockpath = os.path.join(self.tempdir, 'clientsock') | |||
| # Setup PT client endpoint | |||
| sockpath = os.path.join(self.tempdir, 'servptsock') | |||
| ep = endpoints.UNIXServerEndpoint(reactor, sockpath) | |||
| lpobj = yield ep.listen(AccProtFactory(self)) | |||
| self.testserv = ep | |||
| self.listenportobj = lpobj | |||
| self.endpoint = 'unix:path=%s' % sockpath | |||
| self.endpoint = _makeunix(sockpath) | |||
| factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||
| self.proto = factory.buildProtocol(None) | |||
| self.tr = proto_helpers.StringTransport() | |||
| self.proto.makeConnection(self.tr) | |||
| self.client_key_pair = genkeypair() | |||
| # Setup server, and configure where to connect to. | |||
| self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||
| @defer.inlineCallbacks | |||
| def tearDown(self): | |||
| self.listenportobj.stopListening() | |||
| d = yield self.listenportobj.stopListening() | |||
| shutil.rmtree(self.basetempdir) | |||
| self.tempdir = None | |||
| @@ -223,9 +231,16 @@ class TNServerTest(unittest.TestCase): | |||
| # | |||
| # proto (NoiseConnection) -> self.tr (StringTransport) -> | |||
| # self.proto (TwistedNoiseServerProtocol) -> | |||
| # self.proto.endpoint (ClientProxyProtocol) -> unix sock -> | |||
| # self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> | |||
| # self.protos[0] (AccumulatingProtocol) | |||
| # | |||
| # Generate a server protocol, and bind it to a string | |||
| # transport for testing | |||
| self.proto = self.servfactory.buildProtocol(None) | |||
| self.tr = proto_helpers.StringTransport() | |||
| self.proto.makeConnection(self.tr) | |||
| # Create client | |||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| proto.set_as_initiator() | |||
| @@ -274,6 +289,7 @@ class TNServerTest(unittest.TestCase): | |||
| # Feed it into the protocol | |||
| self.proto.dataReceived(encmsg) | |||
| # XXX - fix | |||
| # wait to pass it through | |||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||
| @@ -285,6 +301,7 @@ class TNServerTest(unittest.TestCase): | |||
| rptmsg = b'this is a different test message going the other way' | |||
| clientend.transport.write(rptmsg) | |||
| # XXX - fix | |||
| # wait to pass it through | |||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||
| @@ -294,3 +311,16 @@ class TNServerTest(unittest.TestCase): | |||
| # clean up connection | |||
| clientend.transport.loseConnection() | |||
| @defer.inlineCallbacks | |||
| def test_clientserver(self): | |||
| # Path that the client "listener" sits on. | |||
| cptsockpath = os.path.join(self.tempdir, 'clientptsock') | |||
| # Path that the server sits on | |||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||
| servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) | |||
| servlpobj = yield servep.listen(self.servfactory) | |||
| d = yield servlpobj.stopListening() | |||