A stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

461 lines
12 KiB

  1. from noise.connection import NoiseConnection, Keypair
  2. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  3. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  4. from cryptography.hazmat.primitives import hashes
  5. from cryptography.hazmat.backends import default_backend
  6. from cryptography.hazmat.primitives.asymmetric import x448
  7. from cryptography.hazmat.primitives import serialization
  8. import asyncio
  9. import os.path
  10. import shutil
  11. import socket
  12. import tempfile
  13. import threading
  14. import unittest
  15. _backend = default_backend()
  16. def genkeypair():
  17. '''Generates a keypair, and returns a tuple of (public, private).
  18. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  19. key = x448.X448PrivateKey.generate()
  20. enc = serialization.Encoding.Raw
  21. pubformat = serialization.PublicFormat.Raw
  22. privformat = serialization.PrivateFormat.Raw
  23. encalgo = serialization.NoEncryption()
  24. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  25. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  26. return pub, priv
  27. def _makefut(obj):
  28. loop = asyncio.get_running_loop()
  29. fut = loop.create_future()
  30. fut.set_result(obj)
  31. return fut
  32. def _makeunix(path):
  33. '''Make a properly formed unix path socket string.'''
  34. return 'unix:%s' % path
  35. def _parsesockstr(sockstr):
  36. proto, rem = sockstr.split(':', 1)
  37. return proto, rem
  38. async def connectsockstr(sockstr):
  39. proto, rem = _parsesockstr(sockstr)
  40. reader, writer = await asyncio.open_unix_connection(rem)
  41. return reader, writer
  42. async def listensockstr(sockstr, cb):
  43. '''Wrapper for asyncio.start_x_server.
  44. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  45. If the proto has a default parameter, the value can be used
  46. directly, like: 'proto:value'. This is only allowed when the
  47. value can unambiguously be determined not to be a param.
  48. The characters that define 'param' must be all lower case ascii
  49. characters and may contain an underscore. The first character
  50. must not be and underscore.
  51. Supported protocols:
  52. unix:
  53. Default parameter is path.
  54. The path parameter specifies the path to the
  55. unix domain socket. The path MUST start w/ a
  56. slash if it is used as a default parameter.
  57. '''
  58. proto, rem = _parsesockstr(sockstr)
  59. server = await asyncio.start_unix_server(cb, path=rem)
  60. return server
  61. # !!python makemessagelengths.py
  62. _handshakelens = \
  63. [72, 72, 88]
  64. def _genciphfun(hash, ad):
  65. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  66. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  67. key = hkdf.derive(hash)
  68. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  69. backend=_backend)
  70. enctor = cipher.encryptor()
  71. def encfun(data):
  72. # Returns the two bytes for length
  73. val = len(data)
  74. encbytes = enctor.update(data[:16])
  75. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  76. return (val ^ mask).to_bytes(length=2, byteorder='big')
  77. def decfun(data):
  78. # takes off the data and returns the total
  79. # length
  80. val = int.from_bytes(data[:2], byteorder='big')
  81. encbytes = enctor.update(data[2:2 + 16])
  82. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  83. return val ^ mask
  84. return encfun, decfun
  85. async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
  86. rdr, wrr = await rdrwrr
  87. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  88. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  89. if pub_key is not None:
  90. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  91. pub_key)
  92. if mode == 'resp':
  93. proto.set_as_responder()
  94. proto.start_handshake()
  95. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  96. wrr.write(proto.write_message())
  97. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  98. elif mode == 'init':
  99. proto.set_as_initiator()
  100. proto.start_handshake()
  101. wrr.write(proto.write_message())
  102. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  103. wrr.write(proto.write_message())
  104. if not proto.handshake_finished: # pragma: no cover
  105. raise RuntimeError('failed to finish handshake')
  106. # generate the keys for lengths
  107. if mode == 'resp':
  108. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  109. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  110. elif mode == 'init':
  111. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  112. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  113. reader, writer = await ptpair
  114. async def decses():
  115. try:
  116. while True:
  117. try:
  118. msg = await rdr.readexactly(2 + 16)
  119. except asyncio.streams.IncompleteReadError:
  120. if rdr.at_eof():
  121. return 'dec'
  122. tlen = declenfun(msg)
  123. rmsg = await rdr.readexactly(tlen - 16)
  124. tmsg = msg[2:] + rmsg
  125. writer.write(proto.decrypt(tmsg))
  126. await writer.drain()
  127. #except:
  128. # import traceback
  129. # traceback.print_exc()
  130. # raise
  131. finally:
  132. writer.write_eof()
  133. async def encses():
  134. try:
  135. while True:
  136. # largest message
  137. ptmsg = await reader.read(65535 - 16)
  138. if not ptmsg:
  139. # eof
  140. return 'enc'
  141. encmsg = proto.encrypt(ptmsg)
  142. wrr.write(enclenfun(encmsg))
  143. wrr.write(encmsg)
  144. await wrr.drain()
  145. #except:
  146. # import traceback
  147. # traceback.print_exc()
  148. # raise
  149. finally:
  150. wrr.write_eof()
  151. return await asyncio.gather(decses(), encses())
  152. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  153. # Slightly modified to timeout
  154. def async_test(f):
  155. def wrapper(*args, **kwargs):
  156. coro = asyncio.coroutine(f)
  157. future = coro(*args, **kwargs)
  158. loop = asyncio.get_event_loop()
  159. # timeout after 2 seconds
  160. loop.run_until_complete(asyncio.wait_for(future, 2))
  161. return wrapper
  162. class Tests_misc(unittest.TestCase):
  163. def test_listensockstr(self):
  164. # XXX write test
  165. pass
  166. def test_genciphfun(self):
  167. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  168. msg = b'this is a bunch of data'
  169. tb = enc(msg)
  170. self.assertEqual(len(msg), dec(tb + msg))
  171. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  172. msg = os.urandom(i)
  173. self.assertEqual(len(msg), dec(enc(msg) + msg))
  174. def _asyncsockpair():
  175. '''Create a pair of sockets that are bound to each other.
  176. The function will return a tuple of two coroutine's, that
  177. each, when await'ed upon, will return the reader/writer pair.'''
  178. socka, sockb = socket.socketpair()
  179. return asyncio.open_connection(sock=socka), \
  180. asyncio.open_connection(sock=sockb)
  181. class Tests(unittest.TestCase):
  182. def setUp(self):
  183. # setup temporary directory
  184. d = os.path.realpath(tempfile.mkdtemp())
  185. self.basetempdir = d
  186. self.tempdir = os.path.join(d, 'subdir')
  187. os.mkdir(self.tempdir)
  188. # Generate key pairs
  189. self.server_key_pair = genkeypair()
  190. self.client_key_pair = genkeypair()
  191. def tearDown(self):
  192. shutil.rmtree(self.basetempdir)
  193. self.tempdir = None
  194. @async_test
  195. async def test_server(self):
  196. # Test is plumbed:
  197. # (reader, writer) -> servsock ->
  198. # (rdr, wrr) NoiseForward (reader, writer) ->
  199. # servptsock -> (ptsock[0], ptsock[1])
  200. # Path that the server will sit on
  201. servsockpath = os.path.join(self.tempdir, 'servsock')
  202. servarg = _makeunix(servsockpath)
  203. # Path that the server will send pt data to
  204. servptpath = os.path.join(self.tempdir, 'servptsock')
  205. # Setup pt target listener
  206. pttarg = _makeunix(servptpath)
  207. ptsock = []
  208. def ptsockaccept(reader, writer, ptsock=ptsock):
  209. ptsock.append((reader, writer))
  210. # Bind to pt listener
  211. lsock = await listensockstr(pttarg, ptsockaccept)
  212. nfs = []
  213. event = asyncio.Event()
  214. async def runnf(rdr, wrr):
  215. ptpair = asyncio.create_task(connectsockstr(pttarg))
  216. a = await NoiseForwarder('resp',
  217. _makefut((rdr, wrr)), ptpair,
  218. priv_key=self.server_key_pair[1])
  219. nfs.append(a)
  220. event.set()
  221. # Setup server listener
  222. ssock = await listensockstr(servarg, runnf)
  223. # Connect to server
  224. reader, writer = await connectsockstr(servarg)
  225. # Create client
  226. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  227. proto.set_as_initiator()
  228. # Setup required keys
  229. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  230. self.client_key_pair[1])
  231. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  232. self.server_key_pair[0])
  233. proto.start_handshake()
  234. # Send first message
  235. message = proto.write_message()
  236. self.assertEqual(len(message), _handshakelens[0])
  237. writer.write(message)
  238. # Get response
  239. respmsg = await reader.readexactly(_handshakelens[1])
  240. proto.read_message(respmsg)
  241. # Send final reply
  242. message = proto.write_message()
  243. writer.write(message)
  244. # Make sure handshake has completed
  245. self.assertTrue(proto.handshake_finished)
  246. # generate the keys for lengths
  247. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  248. b'toresp')
  249. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  250. b'toinit')
  251. # write a test message
  252. ptmsg = b'this is a test message that should be a little in length'
  253. encmsg = proto.encrypt(ptmsg)
  254. writer.write(enclenfun(encmsg))
  255. writer.write(encmsg)
  256. # XXX - how to sync?
  257. await asyncio.sleep(.1)
  258. ptreader, ptwriter = ptsock[0]
  259. # read the test message
  260. rptmsg = await ptreader.readexactly(len(ptmsg))
  261. self.assertEqual(rptmsg, ptmsg)
  262. # write a different message
  263. ptmsg = os.urandom(2843)
  264. encmsg = proto.encrypt(ptmsg)
  265. writer.write(enclenfun(encmsg))
  266. writer.write(encmsg)
  267. # XXX - how to sync?
  268. await asyncio.sleep(.1)
  269. # read the test message
  270. rptmsg = await ptreader.readexactly(len(ptmsg))
  271. self.assertEqual(rptmsg, ptmsg)
  272. # now try the other way
  273. ptmsg = os.urandom(912)
  274. ptwriter.write(ptmsg)
  275. # find out how much we need to read
  276. encmsg = await reader.readexactly(2 + 16)
  277. tlen = declenfun(encmsg)
  278. # read the rest of the message
  279. rencmsg = await reader.readexactly(tlen - 16)
  280. tmsg = encmsg[2:] + rencmsg
  281. rptmsg = proto.decrypt(tmsg)
  282. self.assertEqual(rptmsg, ptmsg)
  283. # shut down sending
  284. writer.write_eof()
  285. # so pt reader should be shut down
  286. self.assertEqual(b'', await ptreader.read(1))
  287. self.assertTrue(ptreader.at_eof())
  288. # shut down pt
  289. ptwriter.write_eof()
  290. # make sure the enc reader is eof
  291. self.assertEqual(b'', await reader.read(1))
  292. self.assertTrue(reader.at_eof())
  293. await event.wait()
  294. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  295. @async_test
  296. async def test_serverclient(self):
  297. # plumbing:
  298. #
  299. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  300. #
  301. ptcsockapair, ptcsockbpair = _asyncsockpair()
  302. ptcareader, ptcawriter = await ptcsockapair
  303. #ptcsockbpair passed directly
  304. clssockapair, clssockbpair = _asyncsockpair()
  305. #both passed directly
  306. ptssockapair, ptssockbpair = _asyncsockpair()
  307. #ptssockapair passed directly
  308. ptsbreader, ptsbwriter = await ptssockbpair
  309. clientnf = asyncio.create_task(NoiseForwarder('init',
  310. clssockapair, ptcsockbpair,
  311. priv_key=self.client_key_pair[1],
  312. pub_key=self.server_key_pair[0]))
  313. servnf = asyncio.create_task(NoiseForwarder('resp',
  314. clssockbpair, ptssockapair,
  315. priv_key=self.server_key_pair[1]))
  316. # send a message
  317. msga = os.urandom(183)
  318. ptcawriter.write(msga)
  319. # make sure we get the same message
  320. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  321. # send a second message
  322. msga = os.urandom(2834)
  323. ptcawriter.write(msga)
  324. # make sure we get the same message
  325. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  326. # send a message larger than the block size
  327. msga = os.urandom(103958)
  328. ptcawriter.write(msga)
  329. # make sure we get the same message
  330. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  331. # send a message the other direction
  332. msga = os.urandom(103958)
  333. ptsbwriter.write(msga)
  334. # make sure we get the same message
  335. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  336. # close down the pt writers, the rest should follow
  337. ptsbwriter.write_eof()
  338. ptcawriter.write_eof()
  339. # make sure they are closed, and there is no more data
  340. self.assertEqual(b'', await ptsbreader.read(1))
  341. self.assertTrue(ptsbreader.at_eof())
  342. self.assertEqual(b'', await ptcareader.read(1))
  343. self.assertTrue(ptcareader.at_eof())
  344. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  345. self.assertEqual([ 'dec', 'enc' ], await servnf)