A stunnel-like program that uses the Noise protocol.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

255 lines
8.1 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
  15. __author__ = 'John-Mark Gurney'
  16. __copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
  17. __license__ = '2-clause BSD license'
  18. # Copyright 2019 John-Mark Gurney.
  19. # All rights reserved.
  20. #
  21. # Redistribution and use in source and binary forms, with or without
  22. # modification, are permitted provided that the following conditions
  23. # are met:
  24. # 1. Redistributions of source code must retain the above copyright
  25. # notice, this list of conditions and the following disclaimer.
  26. # 2. Redistributions in binary form must reproduce the above copyright
  27. # notice, this list of conditions and the following disclaimer in the
  28. # documentation and/or other materials provided with the distribution.
  29. #
  30. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  31. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  32. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  33. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  34. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  35. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  36. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  37. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  38. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  39. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  40. # SUCH DAMAGE.
  41. # Notes:
  42. # Using XK, so that the connecting party's identity is hidden from
  43. # eavesdroppers and clients must already know the server's static key.
  44. def genkeypair():
  45. '''Generates a keypair, and returns a tuple of (public, private).
  46. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  47. key = x448.X448PrivateKey.generate()
  48. enc = serialization.Encoding.Raw
  49. pubformat = serialization.PublicFormat.Raw
  50. privformat = serialization.PrivateFormat.Raw
  51. encalgo = serialization.NoEncryption()
  52. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  53. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  54. return pub, priv
  55. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  56. '''This class acts as a Noise Protocol responder. The factory that
  57. creates this Protocol is required to have the properties server_key
  58. and endpoint.
  59. The server_key propery is the key for the server that the clients are
  60. required to have (due to Noise XK protocol used) to authenticate the
  61. server.
  62. The endpoint property contains the endpoint as a string that will be
  63. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  64. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  65. for information on how to use this property.'''
  66. def connectionMade(self):
  67. # Initialize Noise
  68. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  69. self.noise = noise
  70. noise.set_as_responder()
  71. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  72. # Start Handshake
  73. noise.start_handshake()
  74. def encData(self, data):
  75. self.transport.write(self.noise.encrypt(data))
  76. def dataReceived(self, data):
  77. if not self.noise.handshake_finished:
  78. self.noise.read_message(data)
  79. if not self.noise.handshake_finished:
  80. self.transport.write(self.noise.write_message())
  81. if self.noise.handshake_finished:
  82. self.transport.pauseProducing()
  83. # start the connection to the endpoint
  84. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  85. epdef = ep.connect(ClientProxyFactory(self))
  86. epdef.addCallback(self.proxyConnected)
  87. else:
  88. r = self.noise.decrypt(data)
  89. self.endpoint.transport.write(r)
  90. def proxyConnected(self, endpoint):
  91. self.endpoint = endpoint
  92. self.transport.resumeProducing()
  93. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  94. def dataReceived(self, data):
  95. self.factory.noiseproto.encData(data)
  96. class ClientProxyFactory(Factory):
  97. protocol = ClientProxyProtocol
  98. def __init__(self, noiseproto):
  99. self.noiseproto = noiseproto
  100. class TwistedNoiseServerFactory(Factory):
  101. protocol = TwistedNoiseServerProtocol
  102. def __init__(self, server_key, endpoint):
  103. self.server_key = server_key
  104. self.endpoint = endpoint
  105. class TNServerTest(unittest.TestCase):
  106. @defer.inlineCallbacks
  107. def setUp(self):
  108. d = os.path.realpath(tempfile.mkdtemp())
  109. self.basetempdir = d
  110. self.tempdir = os.path.join(d, 'subdir')
  111. os.mkdir(self.tempdir)
  112. self.server_key_pair = genkeypair()
  113. self.protos = []
  114. self.connectionmade = defer.Deferred()
  115. class AccProtFactory(Factory):
  116. protocol = proto_helpers.AccumulatingProtocol
  117. def __init__(self, tc):
  118. self.__tc = tc
  119. Factory.__init__(self)
  120. protocolConnectionMade = self.connectionmade
  121. def buildProtocol(self, addr):
  122. r = Factory.buildProtocol(self, addr)
  123. self.__tc.protos.append(r)
  124. return r
  125. sockpath = os.path.join(self.tempdir, 'clientsock')
  126. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  127. lpobj = yield ep.listen(AccProtFactory(self))
  128. self.testserv = ep
  129. self.listenportobj = lpobj
  130. self.endpoint = 'unix:path=%s' % sockpath
  131. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  132. self.proto = factory.buildProtocol(None)
  133. self.tr = proto_helpers.StringTransport()
  134. self.proto.makeConnection(self.tr)
  135. self.client_key_pair = genkeypair()
  136. def tearDown(self):
  137. self.listenportobj.stopListening()
  138. shutil.rmtree(self.basetempdir)
  139. self.tempdir = None
  140. @defer.inlineCallbacks
  141. def test_testprotocol(self):
  142. #
  143. # How this test is plumbed:
  144. #
  145. # proto (NoiseConnection) -> self.tr (StringTransport) ->
  146. # self.proto (TwistedNoiseServerProtocol) ->
  147. # self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
  148. # self.protos[0] (AccumulatingProtocol)
  149. #
  150. # Create client
  151. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  152. proto.set_as_initiator()
  153. # Setup required keys
  154. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  155. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  156. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  157. proto.start_handshake()
  158. # Send first message
  159. message = proto.write_message()
  160. self.proto.dataReceived(message)
  161. # Get response
  162. resp = self.tr.value()
  163. self.tr.clear()
  164. # And process it
  165. proto.read_message(resp)
  166. # Send second message
  167. message = proto.write_message()
  168. self.proto.dataReceived(message)
  169. # assert handshake finished
  170. self.assertTrue(proto.handshake_finished)
  171. # Make sure incoming data is paused till we establish client
  172. # connection, otherwise no place to write the data
  173. self.assertEqual(self.tr.producerState, 'paused')
  174. # Wait for the connection to be made
  175. d = yield self.connectionmade
  176. d = yield task.deferLater(reactor, .1, bool, 1)
  177. # How to make this ready?
  178. self.assertEqual(self.tr.producerState, 'producing')
  179. # Encrypt the message
  180. ptmsg = b'this is a test message'
  181. encmsg = proto.encrypt(ptmsg)
  182. # Feed it into the protocol
  183. self.proto.dataReceived(encmsg)
  184. # wait to pass it through
  185. d = yield task.deferLater(reactor, .1, bool, 1)
  186. # fetch remote end out
  187. clientend = self.protos[0]
  188. self.assertEqual(clientend.data, ptmsg)
  189. # send a message the other direction
  190. rptmsg = b'this is a different test message going the other way'
  191. clientend.transport.write(rptmsg)
  192. # wait to pass it through
  193. d = yield task.deferLater(reactor, .1, bool, 1)
  194. # receive it and decrypt it
  195. resp = self.tr.value()
  196. self.assertEqual(proto.decrypt(resp), rptmsg)
  197. # clean up connection
  198. clientend.transport.loseConnection()