| @@ -60,6 +60,11 @@ __license__ = '2-clause BSD license' | |||||
| # the connection aborts because of decryption failure. | # the connection aborts because of decryption failure. | ||||
| # | # | ||||
| def _makeunix(path): | |||||
| '''Make a properly formed unix path socket string.''' | |||||
| return 'unix:%s' % path | |||||
| def genkeypair(): | def genkeypair(): | ||||
| '''Generates a keypair, and returns a tuple of (public, private). | '''Generates a keypair, and returns a tuple of (public, private). | ||||
| They are encoded as raw bytes, and sutible for use w/ Noise.''' | They are encoded as raw bytes, and sutible for use w/ Noise.''' | ||||
| @@ -146,18 +151,18 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||||
| # start the connection to the endpoint | # start the connection to the endpoint | ||||
| ep = endpoints.clientFromString(reactor, self.factory.endpoint) | ep = endpoints.clientFromString(reactor, self.factory.endpoint) | ||||
| epdef = ep.connect(ClientProxyFactory(self)) | |||||
| epdef = ep.connect(ServerPTProxyFactory(self)) | |||||
| epdef.addCallback(self.plaintextConnected) | epdef.addCallback(self.plaintextConnected) | ||||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | ||||
| mode = 'init' | mode = 'init' | ||||
| class ClientProxyProtocol(twisted.internet.protocol.Protocol): | |||||
| class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): | |||||
| def dataReceived(self, data): | def dataReceived(self, data): | ||||
| self.factory.noiseproto.encData(data) | self.factory.noiseproto.encData(data) | ||||
| class ClientProxyFactory(Factory): | |||||
| protocol = ClientProxyProtocol | |||||
| class ServerPTProxyFactory(Factory): | |||||
| protocol = ServerPTProxyProtocol | |||||
| def __init__(self, noiseproto): | def __init__(self, noiseproto): | ||||
| self.noiseproto = noiseproto | self.noiseproto = noiseproto | ||||
| @@ -172,12 +177,17 @@ class TwistedNoiseServerFactory(Factory): | |||||
| class TNServerTest(unittest.TestCase): | class TNServerTest(unittest.TestCase): | ||||
| @defer.inlineCallbacks | @defer.inlineCallbacks | ||||
| def setUp(self): | def setUp(self): | ||||
| # setup temporary directory | |||||
| d = os.path.realpath(tempfile.mkdtemp()) | d = os.path.realpath(tempfile.mkdtemp()) | ||||
| self.basetempdir = d | self.basetempdir = d | ||||
| self.tempdir = os.path.join(d, 'subdir') | self.tempdir = os.path.join(d, 'subdir') | ||||
| os.mkdir(self.tempdir) | os.mkdir(self.tempdir) | ||||
| # Generate key pairs | |||||
| self.server_key_pair = genkeypair() | self.server_key_pair = genkeypair() | ||||
| self.client_key_pair = genkeypair() | |||||
| # Server's PT client will be here | |||||
| self.protos = [] | self.protos = [] | ||||
| self.connectionmade = defer.Deferred() | self.connectionmade = defer.Deferred() | ||||
| @@ -195,23 +205,21 @@ class TNServerTest(unittest.TestCase): | |||||
| self.__tc.protos.append(r) | self.__tc.protos.append(r) | ||||
| return r | return r | ||||
| sockpath = os.path.join(self.tempdir, 'clientsock') | |||||
| # Setup PT client endpoint | |||||
| sockpath = os.path.join(self.tempdir, 'servptsock') | |||||
| ep = endpoints.UNIXServerEndpoint(reactor, sockpath) | ep = endpoints.UNIXServerEndpoint(reactor, sockpath) | ||||
| lpobj = yield ep.listen(AccProtFactory(self)) | lpobj = yield ep.listen(AccProtFactory(self)) | ||||
| self.testserv = ep | self.testserv = ep | ||||
| self.listenportobj = lpobj | self.listenportobj = lpobj | ||||
| self.endpoint = 'unix:path=%s' % sockpath | |||||
| self.endpoint = _makeunix(sockpath) | |||||
| factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||||
| self.proto = factory.buildProtocol(None) | |||||
| self.tr = proto_helpers.StringTransport() | |||||
| self.proto.makeConnection(self.tr) | |||||
| self.client_key_pair = genkeypair() | |||||
| # Setup server, and configure where to connect to. | |||||
| self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||||
| @defer.inlineCallbacks | |||||
| def tearDown(self): | def tearDown(self): | ||||
| self.listenportobj.stopListening() | |||||
| d = yield self.listenportobj.stopListening() | |||||
| shutil.rmtree(self.basetempdir) | shutil.rmtree(self.basetempdir) | ||||
| self.tempdir = None | self.tempdir = None | ||||
| @@ -223,9 +231,16 @@ class TNServerTest(unittest.TestCase): | |||||
| # | # | ||||
| # proto (NoiseConnection) -> self.tr (StringTransport) -> | # proto (NoiseConnection) -> self.tr (StringTransport) -> | ||||
| # self.proto (TwistedNoiseServerProtocol) -> | # self.proto (TwistedNoiseServerProtocol) -> | ||||
| # self.proto.endpoint (ClientProxyProtocol) -> unix sock -> | |||||
| # self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> | |||||
| # self.protos[0] (AccumulatingProtocol) | # self.protos[0] (AccumulatingProtocol) | ||||
| # | # | ||||
| # Generate a server protocol, and bind it to a string | |||||
| # transport for testing | |||||
| self.proto = self.servfactory.buildProtocol(None) | |||||
| self.tr = proto_helpers.StringTransport() | |||||
| self.proto.makeConnection(self.tr) | |||||
| # Create client | # Create client | ||||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | ||||
| proto.set_as_initiator() | proto.set_as_initiator() | ||||
| @@ -274,6 +289,7 @@ class TNServerTest(unittest.TestCase): | |||||
| # Feed it into the protocol | # Feed it into the protocol | ||||
| self.proto.dataReceived(encmsg) | self.proto.dataReceived(encmsg) | ||||
| # XXX - fix | |||||
| # wait to pass it through | # wait to pass it through | ||||
| d = yield task.deferLater(reactor, .1, bool, 1) | d = yield task.deferLater(reactor, .1, bool, 1) | ||||
| @@ -285,6 +301,7 @@ class TNServerTest(unittest.TestCase): | |||||
| rptmsg = b'this is a different test message going the other way' | rptmsg = b'this is a different test message going the other way' | ||||
| clientend.transport.write(rptmsg) | clientend.transport.write(rptmsg) | ||||
| # XXX - fix | |||||
| # wait to pass it through | # wait to pass it through | ||||
| d = yield task.deferLater(reactor, .1, bool, 1) | d = yield task.deferLater(reactor, .1, bool, 1) | ||||
| @@ -294,3 +311,16 @@ class TNServerTest(unittest.TestCase): | |||||
| # clean up connection | # clean up connection | ||||
| clientend.transport.loseConnection() | clientend.transport.loseConnection() | ||||
| @defer.inlineCallbacks | |||||
| def test_clientserver(self): | |||||
| # Path that the client "listener" sits on. | |||||
| cptsockpath = os.path.join(self.tempdir, 'clientptsock') | |||||
| # Path that the server sits on | |||||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||||
| servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) | |||||
| servlpobj = yield servep.listen(self.servfactory) | |||||
| d = yield servlpobj.stopListening() | |||||