| @@ -2,7 +2,7 @@ from twisted.trial import unittest | |||
| from twisted.test import proto_helpers | |||
| from noise.connection import NoiseConnection, Keypair | |||
| from twisted.internet.protocol import Factory | |||
| from twisted.internet import endpoints, reactor,defer | |||
| from twisted.internet import endpoints, reactor, defer, task | |||
| # XXX - shouldn't need to access the underlying primitives, but that's what | |||
| # noiseprotocol module requires. | |||
| from cryptography.hazmat.primitives.asymmetric import x448 | |||
| @@ -70,10 +70,9 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): | |||
| else: | |||
| r = self.noise.decrypt(data) | |||
| self.endpoint.write(r) | |||
| self.endpoint.transport.write(r) | |||
| def proxyConnected(self, endpoint): | |||
| print('pc') | |||
| self.endpoint = endpoint | |||
| self.transport.resumeProducing() | |||
| @@ -98,6 +97,8 @@ class TNServerTest(unittest.TestCase): | |||
| def setUp(self): | |||
| self.server_key_pair = genkeypair() | |||
| self.protos = [] | |||
| self.connectionmade = defer.Deferred() | |||
| class AccProtFactory(Factory): | |||
| protocol = proto_helpers.AccumulatingProtocol | |||
| @@ -105,19 +106,21 @@ class TNServerTest(unittest.TestCase): | |||
| self.__tc = tc | |||
| Factory.__init__(self) | |||
| protocolConnectionMade = self.connectionmade | |||
| def buildProtocol(self, addr): | |||
| r = Factory.buildProtocol(addr) | |||
| self.__tc.append(r) | |||
| r = Factory.buildProtocol(self, addr) | |||
| self.__tc.protos.append(r) | |||
| return r | |||
| for i in range(10000, 20000): | |||
| ep = endpoints.TCP4ServerEndpoint(reactor, i) | |||
| ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1') | |||
| try: | |||
| lpobj = yield ep.listen(AccProtFactory(self)) | |||
| except Exception: | |||
| except Exception: # pragma: no cover | |||
| continue | |||
| break | |||
| else: | |||
| else: # pragma: no cover | |||
| raise RuntimeError('all ports occupied') | |||
| self.testserv = ep | |||
| @@ -133,8 +136,8 @@ class TNServerTest(unittest.TestCase): | |||
| def tearDown(self): | |||
| self.listenportobj.stopListening() | |||
| @mock.patch('twisted.internet.endpoints.clientFromString') | |||
| def test_testprotocol(self, cfs): | |||
| @defer.inlineCallbacks | |||
| def test_testprotocol(self): | |||
| # Create client | |||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| proto.set_as_initiator() | |||
| @@ -157,9 +160,6 @@ class TNServerTest(unittest.TestCase): | |||
| # And process it | |||
| proto.read_message(resp) | |||
| clientconnection = defer.Deferred() | |||
| cfs().connect.return_value = clientconnection | |||
| # Send second message | |||
| message = proto.write_message() | |||
| self.proto.dataReceived(message) | |||
| @@ -171,21 +171,13 @@ class TNServerTest(unittest.TestCase): | |||
| # connection, otherwise no place to write the data | |||
| self.assertEqual(self.tr.producerState, 'paused') | |||
| # Make sure that clientFromString is called properly | |||
| cfs.assert_called_with(reactor, self.endpoint) | |||
| # And that it was connect'ed | |||
| cfs().connect.assert_called() | |||
| # Wait for the connection to be made | |||
| d = yield self.connectionmade | |||
| # and that ClientProxyFactory was called properly | |||
| args = cfs().connect.call_args.args | |||
| self.assertIsInstance(args[0], ClientProxyFactory) | |||
| self.assertIs(args[0].noiseproto, self.proto) | |||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||
| # Simulate that a connection has happened | |||
| remoteend = proto_helpers.StringTransport() | |||
| remoteproto = args[0].buildProtocol(None) | |||
| remoteproto.makeConnection(remoteend) | |||
| # How to make this ready? | |||
| self.assertEqual(self.tr.producerState, 'producing') | |||
| # Encrypt the message | |||
| ptmsg = b'this is a test message' | |||
| @@ -194,4 +186,9 @@ class TNServerTest(unittest.TestCase): | |||
| # Feed it into the protocol | |||
| self.proto.dataReceived(encmsg) | |||
| self.assertEqual(remoteend.value(), ptmsg) | |||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||
| clientend = self.protos[0] | |||
| self.assertEqual(clientend.data, ptmsg) | |||
| clientend.transport.loseConnection() | |||