An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

269 lines
8.7 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
# Module metadata (kept as plain one-line assignments so that tools which
# scan the source text for them keep working).
__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'
  18. # Copyright 2019 John-Mark Gurney.
  19. # All rights reserved.
  20. #
  21. # Redistribution and use in source and binary forms, with or without
  22. # modification, are permitted provided that the following conditions
  23. # are met:
  24. # 1. Redistributions of source code must retain the above copyright
  25. # notice, this list of conditions and the following disclaimer.
  26. # 2. Redistributions in binary form must reproduce the above copyright
  27. # notice, this list of conditions and the following disclaimer in the
  28. # documentation and/or other materials provided with the distribution.
  29. #
  30. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  31. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  32. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  33. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  34. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  35. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  36. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  37. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  38. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  39. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  40. # SUCH DAMAGE.
  41. # Notes:
# Uses the Noise XK pattern, so that the connecting party's identity is
# hidden and the server's static key is known to the client in advance.
  44. #
# Each Noise transport message is 16 bytes (the authentication tag) longer
# than the data it carries.
  46. #
# Proposed method to hide message lengths:
# Immediately after the handshake completes, each side generates and sends
# an n-byte key that it will use to encrypt (algorithm TBD) its own
# byte counts. The length field will be encrypted as
# E(pktnum, key) XOR (2-byte length).
  52. #
# Note that authenticating the message length is NOT needed, because
# the Noise message blocks themselves are authenticated. The worst
# that could happen is that a larger read (64k) is done, and then the
# connection aborts because of a decryption failure.
  57. #
  58. def genkeypair():
  59. '''Generates a keypair, and returns a tuple of (public, private).
  60. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  61. key = x448.X448PrivateKey.generate()
  62. enc = serialization.Encoding.Raw
  63. pubformat = serialization.PublicFormat.Raw
  64. privformat = serialization.PrivateFormat.Raw
  65. encalgo = serialization.NoEncryption()
  66. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  67. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  68. return pub, priv
  69. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  70. '''This class acts as a Noise Protocol responder. The factory that
  71. creates this Protocol is required to have the properties server_key
  72. and endpoint.
  73. The server_key propery is the key for the server that the clients are
  74. required to have (due to Noise XK protocol used) to authenticate the
  75. server.
  76. The endpoint property contains the endpoint as a string that will be
  77. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  78. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  79. for information on how to use this property.'''
  80. def connectionMade(self):
  81. # Initialize Noise
  82. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  83. self.noise = noise
  84. noise.set_as_responder()
  85. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  86. # Start Handshake
  87. noise.start_handshake()
  88. def encData(self, data):
  89. self.transport.write(self.noise.encrypt(data))
  90. def dataReceived(self, data):
  91. if not self.noise.handshake_finished:
  92. self.noise.read_message(data)
  93. if not self.noise.handshake_finished:
  94. self.transport.write(self.noise.write_message())
  95. if self.noise.handshake_finished:
  96. self.transport.pauseProducing()
  97. # start the connection to the endpoint
  98. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  99. epdef = ep.connect(ClientProxyFactory(self))
  100. epdef.addCallback(self.proxyConnected)
  101. else:
  102. r = self.noise.decrypt(data)
  103. self.endpoint.transport.write(r)
  104. def proxyConnected(self, endpoint):
  105. self.endpoint = endpoint
  106. self.transport.resumeProducing()
  107. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  108. def dataReceived(self, data):
  109. self.factory.noiseproto.encData(data)
  110. class ClientProxyFactory(Factory):
  111. protocol = ClientProxyProtocol
  112. def __init__(self, noiseproto):
  113. self.noiseproto = noiseproto
  114. class TwistedNoiseServerFactory(Factory):
  115. protocol = TwistedNoiseServerProtocol
  116. def __init__(self, server_key, endpoint):
  117. self.server_key = server_key
  118. self.endpoint = endpoint
  119. class TNServerTest(unittest.TestCase):
  120. @defer.inlineCallbacks
  121. def setUp(self):
  122. d = os.path.realpath(tempfile.mkdtemp())
  123. self.basetempdir = d
  124. self.tempdir = os.path.join(d, 'subdir')
  125. os.mkdir(self.tempdir)
  126. self.server_key_pair = genkeypair()
  127. self.protos = []
  128. self.connectionmade = defer.Deferred()
  129. class AccProtFactory(Factory):
  130. protocol = proto_helpers.AccumulatingProtocol
  131. def __init__(self, tc):
  132. self.__tc = tc
  133. Factory.__init__(self)
  134. protocolConnectionMade = self.connectionmade
  135. def buildProtocol(self, addr):
  136. r = Factory.buildProtocol(self, addr)
  137. self.__tc.protos.append(r)
  138. return r
  139. sockpath = os.path.join(self.tempdir, 'clientsock')
  140. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  141. lpobj = yield ep.listen(AccProtFactory(self))
  142. self.testserv = ep
  143. self.listenportobj = lpobj
  144. self.endpoint = 'unix:path=%s' % sockpath
  145. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  146. self.proto = factory.buildProtocol(None)
  147. self.tr = proto_helpers.StringTransport()
  148. self.proto.makeConnection(self.tr)
  149. self.client_key_pair = genkeypair()
  150. def tearDown(self):
  151. self.listenportobj.stopListening()
  152. shutil.rmtree(self.basetempdir)
  153. self.tempdir = None
  154. @defer.inlineCallbacks
  155. def test_testprotocol(self):
  156. #
  157. # How this test is plumbed:
  158. #
  159. # proto (NoiseConnection) -> self.tr (StringTransport) ->
  160. # self.proto (TwistedNoiseServerProtocol) ->
  161. # self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
  162. # self.protos[0] (AccumulatingProtocol)
  163. #
  164. # Create client
  165. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  166. proto.set_as_initiator()
  167. # Setup required keys
  168. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  169. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  170. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  171. proto.start_handshake()
  172. # Send first message
  173. message = proto.write_message()
  174. self.proto.dataReceived(message)
  175. # Get response
  176. resp = self.tr.value()
  177. self.tr.clear()
  178. # And process it
  179. proto.read_message(resp)
  180. # Send second message
  181. message = proto.write_message()
  182. self.proto.dataReceived(message)
  183. # assert handshake finished
  184. self.assertTrue(proto.handshake_finished)
  185. # Make sure incoming data is paused till we establish client
  186. # connection, otherwise no place to write the data
  187. self.assertEqual(self.tr.producerState, 'paused')
  188. # Wait for the connection to be made
  189. d = yield self.connectionmade
  190. d = yield task.deferLater(reactor, .1, bool, 1)
  191. # How to make this ready?
  192. self.assertEqual(self.tr.producerState, 'producing')
  193. # Encrypt the message
  194. ptmsg = b'this is a test message'
  195. encmsg = proto.encrypt(ptmsg)
  196. # Feed it into the protocol
  197. self.proto.dataReceived(encmsg)
  198. # wait to pass it through
  199. d = yield task.deferLater(reactor, .1, bool, 1)
  200. # fetch remote end out
  201. clientend = self.protos[0]
  202. self.assertEqual(clientend.data, ptmsg)
  203. # send a message the other direction
  204. rptmsg = b'this is a different test message going the other way'
  205. clientend.transport.write(rptmsg)
  206. # wait to pass it through
  207. d = yield task.deferLater(reactor, .1, bool, 1)
  208. # receive it and decrypt it
  209. resp = self.tr.value()
  210. self.assertEqual(proto.decrypt(resp), rptmsg)
  211. # clean up connection
  212. clientend.transport.loseConnection()