A stunnel-like program that utilizes the Noise protocol.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

1151 lines
31 KiB

  1. from cryptography.hazmat.backends import default_backend
  2. from cryptography.hazmat.primitives import hashes
  3. from cryptography.hazmat.primitives import serialization
  4. from cryptography.hazmat.primitives.asymmetric import x448
  5. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  6. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  7. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  8. from noise.connection import NoiseConnection, Keypair
  9. #import tracemalloc; tracemalloc.start(100)
  10. import argparse
  11. import asyncio
  12. import base64
  13. import os.path
  14. import shutil
  15. import socket
  16. import sys
  17. import tempfile
  18. import time
  19. import threading
  20. import unittest
  # Shared cryptography backend, resolved once at import time and reused
  # by every helper in this module.
  _backend = default_backend()


  def loadprivkey(fname):
      '''Load an unencrypted PEM private key from the file fname.

      The file is read as ASCII text and handed to
      load_pem_private_key with no password.  Returns the key object
      produced by the cryptography library.'''

      with open(fname, encoding='ascii') as fp:
          data = fp.read().encode('ascii')

      # Use the module level backend instead of resolving it again.
      return load_pem_private_key(data, password=None, backend=_backend)
  def loadprivkeyraw(fname):
      '''Return the private key stored in the PEM file fname as raw,
      unencrypted bytes.'''

      return loadprivkey(fname).private_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PrivateFormat.Raw,
          encryption_algorithm=serialization.NoEncryption())
  def loadpubkeyraw(fname):
      '''Load a public key from the file fname and return its raw bytes.

      The first line of the file must have the form:
      'ntun-x448 <urlsafe base64 of the key>'.  Any further lines are
      ignored.

      Raises RuntimeError, with a message naming the file, if the first
      line is missing, does not have exactly two fields, or has a key
      type other than ntun-x448.'''

      with open(fname, encoding='ascii') as fp:
          # Only the first line carries the key.
          line = fp.readline()

      fields = line.split()
      if len(fields) != 2:
          raise RuntimeError('malformed public key file: %s' % repr(fname))

      keytype, keyvalue = fields
      if keytype != 'ntun-x448':
          raise RuntimeError('unsupported key type %s in: %s' %
              (repr(keytype), repr(fname)))

      return base64.urlsafe_b64decode(keyvalue)
  def genkeypair():
      '''Create a fresh X448 key pair.

      The result is the tuple (public, private), each as raw bytes, which
      is the form the Noise library expects.'''

      raw = serialization.Encoding.Raw
      secret = x448.X448PrivateKey.generate()

      pub = secret.public_key().public_bytes(encoding=raw,
          format=serialization.PublicFormat.Raw)
      priv = secret.private_bytes(encoding=raw,
          format=serialization.PrivateFormat.Raw,
          encryption_algorithm=serialization.NoEncryption())

      return pub, priv
  def _makefut(obj):
      '''Return a future on the running event loop that has already been
      resolved with obj.  Must be called while a loop is running.'''

      done = asyncio.get_running_loop().create_future()
      done.set_result(obj)

      return done
  def _makeunix(path):
      '''Build the socket string that names the unix domain socket at
      path.'''

      template = 'unix:%s'

      return template % path
  # For each supported proto, the parameters that are accepted and the
  # callable used to convert the string value of each parameter.
  _allowedparameters = {
      'unix': {
          'path': str,
      },
      'tcp': {
          'host': str,
          'port': int,
      },
  }


  def _parsesockstr(sockstr):
      '''Parse a socket string to its parts.

      The format is 'proto:param=value[,param2=value2]'.  If there are
      no kwargs (no = after the colon), the string after the colon is
      the value of the default parameter of the proto.  Only unix has
      a default parameter (path), and a bare path MUST start w/ a slash.

      Returns the tuple (proto, args) where args is a dictionary of
      the parameters converted per _allowedparameters.

      Raises ValueError for every kind of malformed string: a missing
      colon, an unsupported proto, a bare value for a proto w/o a
      default parameter, an empty or relative bare path, parameters
      that are not allowed, or a value that fails conversion.'''

      proto, sep, rem = sockstr.partition(':')
      if not sep:
          raise ValueError('missing proto separator (:) in: %s' %
              repr(sockstr))

      try:
          allowed = _allowedparameters[proto]
      except KeyError:
          raise ValueError('unsupported proto: %s' % repr(proto)) from None

      if '=' not in rem:
          # Bare value, only valid for a proto w/ a default parameter.
          if proto != 'unix':
              raise ValueError('proto %s requires param=value arguments' %
                  repr(proto))

          # startswith also rejects the empty string w/o indexing it.
          if not rem.startswith('/'):
              raise ValueError('bare path MUST start w/ a slash (/).')

          args = { 'path': rem }
      else:
          try:
              args = dict(i.split('=', 1) for i in rem.split(','))
          except ValueError:
              raise ValueError('malformed parameters: %s' %
                  repr(rem)) from None

      extrakeys = args.keys() - allowed.keys()
      if extrakeys:
          raise ValueError('keys for proto %s not allowed: %s' %
              (repr(proto), extrakeys))

      # The converters raise ValueError themselves on bad values.
      args = { k: allowed[k](v) for k, v in args.items() }

      return proto, args
  async def connectsockstr(sockstr):
      '''Open a connection to the endpoint described by sockstr.

      This wraps asyncio.open_unix_connection for the unix proto and
      asyncio.open_connection for the tcp proto, and returns the
      (reader, writer) pair of the new connection.'''

      proto, kwargs = _parsesockstr(sockstr)

      if proto == 'unix':
          opener = asyncio.open_unix_connection
      elif proto == 'tcp':
          opener = asyncio.open_connection

      streams = await opener(**kwargs)
      rdr, wrr = streams

      return rdr, wrr
  async def listensockstr(sockstr, cb):
      '''Start a server listening on the endpoint described by sockstr.

      This wraps asyncio.start_unix_server for the unix proto and
      asyncio.start_server for the tcp proto, and returns whatever
      that call returns.

      sockstr has the form 'proto:param=value[,param2=value2]'.  When
      the proto has a default parameter, the short form 'proto:value'
      may be used, but only when the value cannot be mistaken for a
      param.

      cb is handed unchanged to the asyncio call.  Per the asyncio
      docs it is called, or scheduled as a task, when a client
      establishes a connection, and receives two arguments, the reader
      and writer streams.  See:
      https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server

      A param name consists of lower case ascii characters and
      underscores, and does not begin w/ an underscore.

      Supported protocols:
          unix:
              The default parameter is path, the path to the
              unix domain socket.  When given in the short form
              the path MUST start w/ a slash.
      '''

      proto, kwargs = _parsesockstr(sockstr)

      if proto == 'unix':
          starter = asyncio.start_unix_server
      elif proto == 'tcp':
          starter = asyncio.start_server

      server = await starter(cb, **kwargs)

      return server
  # Lengths, in bytes, of the handshake messages, in the order they are
  # exchanged.
  # !!python makemessagelengths.py
  _handshakelens = [72, 72, 88]
  def _genciphfun(hash, ad):
      '''Build the pair of functions used to obscure message lengths.

      An AES key is derived w/ HKDF-SHA256 from hash, using ad as the
      info parameter.  A single ECB encryptor made from that key is
      shared by both returned functions.

      Returns the tuple (encfun, decfun):
          encfun(data) returns the two byte, big endian, masked length
              of data.
          decfun(data) takes the two masked length bytes followed by
              the start of the message, and returns the length.'''

      key = HKDF(algorithm=hashes.SHA256(), length=32,
          salt=b'asdoifjsldkjdsf', info=ad, backend=_backend).derive(hash)
      enctor = Cipher(algorithms.AES(key), modes.ECB(),
          backend=_backend).encryptor()

      def lenmask(block):
          # The mask comes from encrypting the leading bytes of the
          # message, and only its low eight bits are used.
          encbytes = enctor.update(block)

          return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

      def encfun(data):
          # Produce the two length bytes that precede data.
          masked = len(data) ^ lenmask(data[:16])

          return masked.to_bytes(length=2, byteorder='big')

      def decfun(data):
          # Recover the length from the two leading bytes, using the
          # sixteen bytes that follow them to regenerate the mask.
          masked = int.from_bytes(data[:2], byteorder='big')

          return masked ^ lenmask(data[2:2 + 16])

      return encfun, decfun
  async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
      '''A function that forwards data between the plain text pair of
      streams to the encrypted session.

      The mode parameter must be one of 'init' or 'resp' for initiator
      and responder.

      The encrdrwrr is an awaitable object that will return a tuple of
      the reader and writer streams for the encrypted side of the
      connection.

      The ptpairfun parameter is a function that will be passed the
      public key bytes for the remote client.  This can be used to
      both validate that the correct client is connecting, and to
      pass back the correct plain text reader/writer objects that
      match the provided static key.  The function must be an async
      function.  If it raises, the encrypted writer is closed and the
      exception is propagated.

      In the case of the initiator, pub_key must be provided and will
      be used to authenticate the responder side of the connection.

      The priv_key parameter is used to authenticate this side of the
      session.

      Both priv_key and pub_key parameters must be 56 bytes.  For example,
      the pair that is returned by genkeypair.

      Returns the list of results of the two forwarding directions,
      [ 'dec', 'enc' ], once both of them have reached end of file.

      Raises RuntimeError if the handshake did not finish.
      '''

      # Local import so that this function does not depend on the
      # import block of the module.
      import errno

      rdr, wrr = await encrdrwrr

      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

      proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
      if pub_key is not None:
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
              pub_key)

      if mode == 'resp':
          proto.set_as_responder()

          proto.start_handshake()

          proto.read_message(await rdr.readexactly(_handshakelens[0]))

          wrr.write(proto.write_message())

          proto.read_message(await rdr.readexactly(_handshakelens[2]))
      elif mode == 'init':
          proto.set_as_initiator()

          proto.start_handshake()

          wrr.write(proto.write_message())

          proto.read_message(await rdr.readexactly(_handshakelens[1]))

          wrr.write(proto.write_message())

      # An unknown mode never starts the handshake, and ends up here.
      if not proto.handshake_finished: # pragma: no cover
          raise RuntimeError('failed to finish handshake')

      try:
          reader, writer = await ptpairfun(getattr(proto.get_keypair(
              Keypair.REMOTE_STATIC), 'public_bytes', None))
      except BaseException:
          # The peer was rejected (or we were cancelled), make sure
          # the encrypted side does not stay open.
          wrr.close()
          raise

      # generate the keys for lengths
      # XXX - get_handshake_hash is probably not the best option, but
      # this is only to obscure lengths, it is not required to be secure
      # as the underlying NoiseProtocol securely validates everything.
      # It is marginally useful as writing patterns likely expose the
      # true length.  Adding padding could marginally help w/ this.
      if mode == 'resp':
          _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
      elif mode == 'init':
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
          _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

      async def decses():
          # Encrypted side -> plain text side.
          try:
              while True:
                  try:
                      msg = await rdr.readexactly(2 + 16)
                  except asyncio.IncompleteReadError:
                      if rdr.at_eof():
                          return 'dec'

                      # Do not fall through w/ a missing or stale msg.
                      raise

                  tlen = declenfun(msg)
                  rmsg = await rdr.readexactly(tlen - 16)
                  tmsg = msg[2:] + rmsg
                  writer.write(proto.decrypt(tmsg))
                  await writer.drain()
          finally:
              try:
                  writer.write_eof()
              except OSError as e:
                  # The plain text side may already be disconnected.
                  # Use the symbolic name, the number differs between
                  # platforms.
                  if e.errno != errno.ENOTCONN:
                      raise

      async def encses():
          # Plain text side -> encrypted side.
          try:
              while True:
                  # largest message
                  ptmsg = await reader.read(65535 - 16)
                  if not ptmsg:
                      # eof
                      return 'enc'

                  encmsg = proto.encrypt(ptmsg)
                  wrr.write(enclenfun(encmsg))
                  wrr.write(encmsg)
                  await wrr.drain()
          finally:
              wrr.write_eof()

      res = await asyncio.gather(decses(), encses())

      await wrr.drain()	# not sure if needed
      wrr.close()

      await writer.drain()	# not sure if needed
      writer.close()

      return res
  # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  # Slightly modified to timeout and to print trace back when canceled.
  # This makes it easier to figure out what "froze".
  def async_test(f):
      '''Decorator that lets an async test function be called as a
      normal function.

      The wrapped coroutine is run to completion on the event loop w/
      a four second timeout.  If it gets cancelled (for example by the
      timeout), the trace back is printed before the cancellation
      propagates.  The wrapper itself returns None.'''

      import functools

      # Keep the name and docstring of the test for reporting.
      @functools.wraps(f)
      def wrapper(*args, **kwargs):
          async def tbcapture():
              try:
                  return await f(*args, **kwargs)
              except asyncio.CancelledError:
                  # if we are going to be cancelled, print out a tb
                  import traceback
                  traceback.print_exc()
                  raise

          try:
              loop = asyncio.get_event_loop()
          except RuntimeError:
              # No current loop is set for this thread, make one.
              loop = asyncio.new_event_loop()
              asyncio.set_event_loop(loop)

          # timeout after 4 seconds
          loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))

      return wrapper
  class Tests_misc(unittest.TestCase):
      '''Tests for the socket string helpers and the length masking
      functions.'''

      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Remember where we were, so that tearDown does not leave the
          # process inside of a directory that has been removed.
          try:
              self.origcwd = os.getcwd()
          except OSError:
              # The current directory is already gone.
              self.origcwd = None

          os.chdir(self.tempdir)

      def tearDown(self):
          # Leave the temporary directory before removing it.
          os.chdir(self.origcwd or tempfile.gettempdir())

          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      def test_parsesockstr_bad(self):
          '''Malformed socket strings raise ValueError.'''

          badstrs = [
              'unix:ff',
              'randomnocolon',
              'unix:somethingelse=bogus',
              'tcp:port=bogus',
          ]

          for i in badstrs:
              with self.assertRaises(ValueError,
                  msg='Should have failed processing: %s' % repr(i)):
                  _parsesockstr(i)

      def test_parsesockstr(self):
          '''Well formed socket strings parse to the expected parts.'''

          results = {
              # Not all of these are valid when passed to a *sockstr
              # function
              'unix:/apath': ('unix', { 'path': '/apath' }),
              'unix:path=apath': ('unix', { 'path': 'apath' }),
              'tcp:host=apath': ('tcp', { 'host': 'apath' }),
              'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
                  'port': 5 }),
          }

          for s, r in results.items():
              self.assertEqual(_parsesockstr(s), r)

      @async_test
      async def test_listensockstr_bad(self):
          '''Unsupported protos are rejected by both wrappers.'''

          with self.assertRaises(ValueError):
              await listensockstr('bogus:some=arg', None)

          with self.assertRaises(ValueError):
              await connectsockstr('bogus:some=arg')

      @async_test
      async def test_listenconnectsockstr(self):
          '''A server made by listensockstr can be reached both directly
          and via connectsockstr.'''

          msgsent = b'this is a test message'
          msgrcv = b'testing message for receive'

          # That when a connection is received and receives and sends
          async def servconfhandle(rdr, wrr):
              msg = await rdr.readexactly(len(msgsent))
              self.assertEqual(msg, msgsent)
              #print(repr(wrr.get_extra_info('sockname')))
              wrr.write(msgrcv)
              await wrr.drain()
              wrr.close()
              return True

          # Test listensockstr
          for sstr, confun in [
              ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
              ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
              ]:
              # that listensockstr will bind to the correct path, can call cb
              ls = await listensockstr(sstr, servconfhandle)

              # that we open a connection to the path
              rdr, wrr = await confun()

              # and send a message
              wrr.write(msgsent)

              # and receive the message
              rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
              self.assertEqual(rcv, msgrcv)

              wrr.close()

              # Now test that connectsockstr works similarly.
              rdr, wrr = await connectsockstr(sstr)

              # and send a message
              wrr.write(msgsent)

              # and receive the message
              rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
              self.assertEqual(rcv, msgrcv)

              wrr.close()

              ls.close()
              await ls.wait_closed()

      def test_genciphfun(self):
          '''The decode function recovers the length that the encode
          function masked.'''

          enc, dec = _genciphfun(b'0' * 32, b'foobar')

          msg = b'this is a bunch of data'

          tb = enc(msg)

          self.assertEqual(len(msg), dec(tb + msg))

          for i in [ 20, 1384, 64000, 23839, 65535 ]:
              msg = os.urandom(i)
              self.assertEqual(len(msg), dec(enc(msg) + msg))
  def cmd_client(args):
      '''Run the client side of the tunnel until interrupted.

      Connections accepted on args.clientlisten are forwarded, as the
      initiator of an encrypted session, to args.clienttarget.  The
      client private key is read from args.clientkey and the server
      public key from args.servkey.'''

      clientpriv = loadprivkeyraw(args.clientkey)
      serverpub = loadpubkeyraw(args.servkey)

      async def runnf(rdr, wrr):
          # The plain text side is the connection that was just
          # accepted, no matter what key the remote side presents.
          def ptpair(remotekey):
              return _makefut((rdr, wrr))

          enctask = asyncio.create_task(connectsockstr(args.clienttarget))

          await NoiseForwarder('init', enctask, ptpair,
              priv_key=clientpriv, pub_key=serverpub)

      # Setup client listener
      listening = listensockstr(args.clientlisten, runnf)

      loop = asyncio.get_event_loop()
      server = loop.run_until_complete(listening)
      loop.run_until_complete(server.serve_forever())
  def cmd_server(args):
      '''Run the server side of the tunnel until interrupted.

      Connections accepted on args.servlisten are handled as the
      responder of an encrypted session.  When the client presents a
      public key that is listed in one of the args.clientkey files, a
      connection is made to args.servtarget and data is forwarded to
      it.  Any other client is rejected w/ a RuntimeError.

      The server private key is read from args.servkey.  If
      args.clientkey is None (no --clientkey was given), no client is
      authorized.'''

      privkey = loadprivkeyraw(args.servkey)

      # argparse leaves an append option as None when it is not used.
      pubkeys = frozenset(loadpubkeyraw(x) for x in args.clientkey or ())

      async def runnf(rdr, wrr):
          async def checkclientfun(clientkey):
              if clientkey not in pubkeys:
                  raise RuntimeError('invalid key provided')

              return await connectsockstr(args.servtarget)

          a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
              checkclientfun, priv_key=privkey)

      # Setup server listener
      ssock = listensockstr(args.servlisten, runnf)

      loop = asyncio.get_event_loop()
      obj = loop.run_until_complete(ssock)
      loop.run_until_complete(obj.serve_forever())
  def cmd_genkey(args):
      '''Generate a new X448 key pair and store it in two files.

      The public key is written to args.fname + '.pub' as a single
      line: 'ntun-x448 <urlsafe base64 of the key>'.  The private key
      is written to args.fname as unencrypted PKCS8 PEM, in a file
      that is created readable and writable by the owner only.

      Neither file is ever overwritten.  If one of them already exists
      a message is printed to stderr and the process exits w/ code 2.
      A public key file created by this call is removed again when the
      private key file cannot be created.'''

      key = x448.X448PrivateKey.generate()

      # public key part
      pub = key.public_key().public_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PublicFormat.Raw)

      pubfname = args.fname + '.pub'
      try:
          with open(pubfname, 'x', encoding='ascii') as fp:
              print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
      except FileExistsError:
          print('failed to create %s, file exists.' % pubfname, file=sys.stderr)
          sys.exit(2)

      # private key part
      pem = key.private_bytes(encoding=serialization.Encoding.PEM,
          format=serialization.PrivateFormat.PKCS8,
          encryption_algorithm=serialization.NoEncryption()).decode('ascii')

      try:
          # Exclusive create, and never readable by other users.
          fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL,
              0o600)
      except FileExistsError:
          # Do not leave behind a public key w/o its private key.
          os.unlink(pubfname)
          print('failed to create %s, file exists.' % args.fname, file=sys.stderr)
          sys.exit(2)

      with os.fdopen(fd, 'w', encoding='ascii') as fp:
          fp.write(pem)
  def main():
      '''Command line entry point.

      Parses the arguments and runs the function of the selected
      subcommand (genkey, server or client).  When no subcommand is
      given, the usage is printed and the process exits w/ code 5.'''

      parser = argparse.ArgumentParser()
      subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')

      # Each entry: name, help, function to run, and the arguments as
      # pairs of (names, keyword arguments for add_argument).
      commands = [
          ('genkey', 'generate keys', cmd_genkey, [
              (('fname',), dict(type=str, help='file name for the key')),
          ]),
          ('server', 'run a server', cmd_server, [
              (('--clientkey', '-c'), dict(action='append', type=str, help='file of authorized client keys, or a .pub file')),
              (('servkey',), dict(type=str, help='file name for the server key')),
              (('servlisten',), dict(type=str, help='Connection that the server listens on')),
              (('servtarget',), dict(type=str, help='Connection that the server connects to')),
          ]),
          ('client', 'run a client', cmd_client, [
              (('clientkey',), dict(type=str, help='file name for the client private key')),
              (('servkey',), dict(type=str, help='file name for the server public key')),
              (('clientlisten',), dict(type=str, help='Connection that the client listens on')),
              (('clienttarget',), dict(type=str, help='Connection that the client connects to')),
          ]),
      ]

      for name, helptext, func, arguments in commands:
          sub = subparsers.add_parser(name, help=helptext)
          for names, options in arguments:
              sub.add_argument(*names, **options)
          sub.set_defaults(func=func)

      args = parser.parse_args()

      # func is only present when a subcommand was selected.
      fun = getattr(args, 'func', None)
      if fun is None:
          parser.print_usage()
          sys.exit(5)

      fun(args)


  if __name__ == '__main__': # pragma: no cover
      main()
  def _asyncsockpair():
      '''Make two connected sockets and wrap each one for asyncio.

      Returns a tuple of two coroutines.  Awaiting one of them gives
      the reader/writer pair for that end of the connection.'''

      ends = socket.socketpair()

      return tuple(asyncio.open_connection(sock=end) for end in ends)
  async def _awaitfile(fname):
      '''Wait until the path fname exists, checking for it every
      hundredth of a second.  Returns True once it does.'''

      while True:
          if os.path.exists(fname):
              return True

          await asyncio.sleep(.01)
  class TestMain(unittest.TestCase):
      '''Tests that run this file as a command in a subprocess.'''

      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

          # Remember where we were, so that tearDown does not leave the
          # process inside of a directory that has been removed.
          try:
              self.origcwd = os.getcwd()
          except OSError:
              # The current directory is already gone.
              self.origcwd = None

          os.chdir(self.tempdir)

      def tearDown(self):
          # Leave the temporary directory before removing it.
          os.chdir(self.origcwd or tempfile.gettempdir())

          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      @async_test
      async def test_noargs(self):
          proc = await self.run_with_args()

          # communicate also drains the pipes, so the child cannot
          # block on a full pipe.
          await proc.communicate()

          # XXX - not checking error message

          # And that it exited w/ the correct code
          self.assertEqual(proc.returncode, 5)

      def run_with_args(self, *args, pipes=True):
          '''Start this file as a subprocess w/ the arguments args.

          When pipes is true, stdout and stderr of the process are
          captured.  Returns the coroutine that creates the process.'''

          kwargs = {}
          if pipes:
              kwargs.update(dict(
                  stdout=asyncio.subprocess.PIPE,
                  stderr=asyncio.subprocess.PIPE))

          return asyncio.create_subprocess_exec(sys.executable,
              # XXX - figure out how to add coverage data on these runs
              #'-m', 'coverage', 'run', '-p',
              __file__, *args, **kwargs)

      async def stopprocs(self, *procs):
          '''Terminate each process of procs that is still running and
          wait for it to exit.  None entries are skipped.'''

          for proc in procs:
              if proc is None:
                  continue

              if proc.returncode is None:
                  try:
                      proc.terminate()
                  except ProcessLookupError:
                      # it exited on its own in the mean time
                      pass

              await proc.communicate()

      async def genkey(self, name):
          '''Run the genkey subcommand for name, and check that it
          succeeded.'''

          proc = await self.run_with_args('genkey', name, pipes=False)

          await proc.wait()

          self.assertEqual(proc.returncode, 0)

      @async_test
      async def test_loadpubkey(self):
          keypath = os.path.join(self.tempdir, 'loadpubkeytest')

          await self.genkey(keypath)

          privkey = loadprivkey(keypath)

          enc = serialization.Encoding.Raw
          pubformat = serialization.PublicFormat.Raw
          pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)

          pubkey = loadpubkeyraw(keypath + '.pub')

          self.assertEqual(pubkeybytes, pubkey)

          privrawkey = loadprivkeyraw(keypath)

          enc = serialization.Encoding.Raw
          privformat = serialization.PrivateFormat.Raw
          encalgo = serialization.NoEncryption()
          rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)

          self.assertEqual(rprivrawkey, privrawkey)

      @async_test
      async def test_clientkeymismatch(self):
          # make sure that if there's a client key mismatch, we
          # don't connect

          # Generate necessary keys
          servkeypath = os.path.join(self.tempdir, 'server_key')
          await self.genkey(servkeypath)
          clientkeypath = os.path.join(self.tempdir, 'client_key')
          await self.genkey(clientkeypath)

          badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
          await self.genkey(badclientkeypath)

          # forwards connection to this socket (created by client)
          ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
          ptclientstr = _makeunix(ptclientpath)

          # this is the socket server listen to
          incservpath = os.path.join(self.tempdir, 'incserv.sock')
          incservstr = _makeunix(incservpath)

          # to this socket, opened by server
          servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
          servtargstr = _makeunix(servtargpath)

          # Setup server target listener
          ptsockevent = asyncio.Event()

          # Record any connection that makes it through, otherwise the
          # check below could never fail.
          def ptsockaccept(reader, writer):
              ptsockevent.set()
              writer.close()

          # Bind to pt listener
          lsock = await listensockstr(servtargstr, ptsockaccept)

          server = client = None
          try:
              # Startup the server
              server = await self.run_with_args('server',
                  '-c', clientkeypath + '.pub',
                  servkeypath, incservstr, servtargstr)

              # Startup the client with the "bad" key
              client = await self.run_with_args('client',
                  badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)

              # wait for server target to be created
              await _awaitfile(servtargpath)

              # wait for server to start
              await _awaitfile(incservpath)

              # wait for client to start
              await _awaitfile(ptclientpath)

              # Connect to the client
              reader, writer = await connectsockstr(ptclientstr)

              # XXX - this might not be the best test.
              with self.assertRaises(asyncio.TimeoutError):
                  # make sure that we don't get the connection
                  await asyncio.wait_for(ptsockevent.wait(), .5)

              writer.close()
          finally:
              # Make sure that neither process outlives the test.
              # XXX - the stderr of the server is not checked for
              # 'Task exception was never retrieved', even the example
              # echo server has this same leak.
              await self.stopprocs(client, server)

              lsock.close()
              await lsock.wait_closed()

      @async_test
      async def test_end2end(self):
          # Generate necessary keys
          servkeypath = os.path.join(self.tempdir, 'server_key')
          await self.genkey(servkeypath)
          clientkeypath = os.path.join(self.tempdir, 'client_key')
          await self.genkey(clientkeypath)

          # forwards connection to this socket (created by client)
          ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
          ptclientstr = _makeunix(ptclientpath)

          # this is the socket server listen to
          incservpath = os.path.join(self.tempdir, 'incserv.sock')
          incservstr = _makeunix(incservpath)

          # to this socket, opened by server
          servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
          servtargstr = _makeunix(servtargpath)

          # Setup server target listener
          ptsock = []
          ptsockevent = asyncio.Event()
          def ptsockaccept(reader, writer, ptsock=ptsock):
              ptsock.append((reader, writer))
              ptsockevent.set()

          # Bind to pt listener
          lsock = await listensockstr(servtargstr, ptsockaccept)

          server = client = None
          try:
              # Startup the server
              server = await self.run_with_args('server',
                  '-c', clientkeypath + '.pub',
                  servkeypath, incservstr, servtargstr,
                  pipes=False)

              # Startup the client
              client = await self.run_with_args('client',
                  clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
                  pipes=False)

              # wait for server target to be created
              await _awaitfile(servtargpath)

              # wait for server to start
              await _awaitfile(incservpath)

              # wait for client to start
              await _awaitfile(ptclientpath)

              # Connect to the client
              reader, writer = await connectsockstr(ptclientstr)

              # send a message
              ptmsg = b'this is a message for testing'
              writer.write(ptmsg)

              # make sure that we got the connection
              await ptsockevent.wait()

              # get the connection
              endrdr, endwrr = ptsock[0]

              # make sure we can read back what we sent
              self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))

              # test some additional messages
              for i in [ 129, 1287, 28792, 129872 ]:
                  # in one direction
                  msg = os.urandom(i)
                  writer.write(msg)
                  self.assertEqual(msg, await endrdr.readexactly(len(msg)))

                  # and the other
                  endwrr.write(msg)
                  self.assertEqual(msg, await reader.readexactly(len(msg)))

              writer.close()
              endwrr.close()
          finally:
              # Make sure that neither process outlives the test.
              await self.stopprocs(client, server)

              lsock.close()
              await lsock.wait_closed()

      @async_test
      async def test_genkey(self):
          # that it can generate a key
          proc = await self.run_with_args('genkey', 'somefile')

          # communicate drains the pipes while waiting for the exit
          await proc.communicate()
          self.assertEqual(proc.returncode, 0)

          with open('somefile.pub', encoding='ascii') as fp:
              lines = fp.readlines()
              self.assertEqual(len(lines), 1)

              keytype, keyvalue = lines[0].split()

          self.assertEqual(keytype, 'ntun-x448')
          key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))

          key = loadprivkey('somefile')
          self.assertIsInstance(key, x448.X448PrivateKey)

          # that a second call fails
          proc = await self.run_with_args('genkey', 'somefile')

          stdoutdata, stderrdata = await proc.communicate()

          self.assertFalse(stdoutdata)
          self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)

          # And that it exited w/ the correct code
          self.assertEqual(proc.returncode, 2)
  class TestNoiseFowarder(unittest.TestCase):
      '''Tests that drive NoiseForwarder directly, in process.'''

      def setUp(self):
          # A scratch directory for the unix sockets of the tests
          base = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = base
          self.tempdir = os.path.join(base, 'subdir')
          os.mkdir(self.tempdir)

          # Fresh keys for both ends
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

      def tearDown(self):
          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      async def _initiatorhandshake(self, privkey, reader, writer):
          '''Do the initiator side of the handshake by hand over
          reader/writer, using privkey as our static key and the
          public key of the test server as the remote key.  Returns the
          NoiseConnection once the handshake has finished.'''

          proto = NoiseConnection.from_name(
              b'Noise_XK_448_ChaChaPoly_SHA256')
          proto.set_as_initiator()

          # our key, and the key we expect of the server
          proto.set_keypair_from_private_bytes(Keypair.STATIC, privkey)
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
              self.server_key_pair[0])

          proto.start_handshake()

          # first message out
          first = proto.write_message()
          self.assertEqual(len(first), _handshakelens[0])
          writer.write(first)

          # the response of the server
          reply = await reader.readexactly(_handshakelens[1])
          proto.read_message(reply)

          # last message out
          final = proto.write_message()
          writer.write(final)

          # and we should be done
          self.assertTrue(proto.handshake_finished)

          return proto

      @async_test
      async def test_clientkeymissmatch(self):
          # a client key that the server does not accept
          wrongclient_key_pair = genkeypair()

          # both ends of the encrypted connection
          clssockapair, clssockbpair = _asyncsockpair()
          reader, writer = await clssockapair

          async def wrongkey(v):
              raise ValueError('no key matches')

          # the server end, which rejects every key
          servnf = asyncio.create_task(NoiseForwarder('resp',
              clssockbpair, wrongkey,
              priv_key=self.server_key_pair[1]))

          # the client end, w/ the wrong client key but the
          # correct server key
          await self._initiatorhandshake(wrongclient_key_pair[1],
              reader, writer)

          # the rejection must come out of the forwarder
          with self.assertRaises(ValueError):
              await servnf

          writer.close()

      @async_test
      async def test_server(self):
          # How things are connected:
          # (reader, writer) -> servsock ->
          # (rdr, wrr) NoiseForward (reader, writer) ->
          # servptsock -> (ptsock[0], ptsock[1])

          # Where the server listens
          servarg = _makeunix(os.path.join(self.tempdir, 'servsock'))

          # Where the server delivers the plain text
          pttarg = _makeunix(os.path.join(self.tempdir, 'servptsock'))

          # Collect the plain text connection when it arrives
          ptsock = []
          ptsockevent = asyncio.Event()
          def ptsockaccept(reader, writer, ptsock=ptsock):
              ptsock.append((reader, writer))
              ptsockevent.set()

          # the plain text listener
          lsock = await listensockstr(pttarg, ptsockaccept)

          nfs = []
          event = asyncio.Event()

          async def runnf(rdr, wrr):
              ptpairfun = asyncio.create_task(connectsockstr(pttarg))

              a = await NoiseForwarder('resp',
                  _makefut((rdr, wrr)), lambda x: ptpairfun,
                  priv_key=self.server_key_pair[1])

              nfs.append(a)
              event.set()

          # the encrypted listener
          ssock = await listensockstr(servarg, runnf)

          # connect, and play the part of the client
          reader, writer = await connectsockstr(servarg)
          proto = await self._initiatorhandshake(self.client_key_pair[1],
              reader, writer)

          # the functions that obscure the lengths
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
              b'toresp')
          _, declenfun = _genciphfun(proto.get_handshake_hash(),
              b'toinit')

          def sendenc(plain):
              # encrypt plain, and send it preceded by its length
              encmsg = proto.encrypt(plain)
              writer.write(enclenfun(encmsg))
              writer.write(encmsg)

          # a first message towards the server
          ptmsg = b'this is a test message that should be a little in length'
          sendenc(ptmsg)

          # the plain text connection shows up
          await ptsockevent.wait()

          ptreader, ptwriter = ptsock[0]

          # and delivers the message
          rptmsg = await ptreader.readexactly(len(ptmsg))

          self.assertEqual(rptmsg, ptmsg)

          # a second, random, message
          ptmsg = os.urandom(2843)
          sendenc(ptmsg)

          rptmsg = await ptreader.readexactly(len(ptmsg))

          self.assertEqual(rptmsg, ptmsg)

          # and one in the opposite direction
          ptmsg = os.urandom(912)
          ptwriter.write(ptmsg)

          # the head of the message tells us its length
          encmsg = await reader.readexactly(2 + 16)
          tlen = declenfun(encmsg)

          # get the remainder
          rencmsg = await reader.readexactly(tlen - 16)
          tmsg = encmsg[2:] + rencmsg
          rptmsg = proto.decrypt(tmsg)

          self.assertEqual(rptmsg, ptmsg)

          # stop sending on the encrypted side
          writer.write_eof()

          # which ends the plain text reader
          self.assertEqual(b'', await ptreader.read(1))
          self.assertTrue(ptreader.at_eof())

          # stop sending on the plain text side
          ptwriter.write_eof()

          # which ends the encrypted reader
          self.assertEqual(b'', await reader.read(1))
          self.assertTrue(reader.at_eof())

          await event.wait()

          self.assertEqual(nfs[0], [ 'dec', 'enc' ])

          writer.close()
          ptwriter.close()
          lsock.close()
          ssock.close()
          await lsock.wait_closed()
          await ssock.wait_closed()

      @async_test
      async def test_serverclient(self):
          # How things are connected:
          #
          # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
          #
          ptcsockapair, ptcsockbpair = _asyncsockpair()
          ptcareader, ptcawriter = await ptcsockapair
          # ptcsockbpair goes to the client forwarder as is

          clssockapair, clssockbpair = _asyncsockpair()
          # both go to the forwarders as is

          ptssockapair, ptssockbpair = _asyncsockpair()
          # ptssockapair goes to the server forwarder as is
          ptsbreader, ptsbwriter = await ptssockbpair

          async def validateclientkey(pubkey):
              self.assertEqual(pubkey, self.client_key_pair[0])

              return await ptssockapair

          clientnf = asyncio.create_task(NoiseForwarder('init',
              clssockapair, lambda x: ptcsockbpair,
              priv_key=self.client_key_pair[1],
              pub_key=self.server_key_pair[0]))

          servnf = asyncio.create_task(NoiseForwarder('resp',
              clssockbpair, validateclientkey,
              priv_key=self.server_key_pair[1]))

          # client to server: a small message, a larger one, and one
          # that is larger than the block size
          for size in [ 183, 2834, 103958 ]:
              msga = os.urandom(size)
              ptcawriter.write(msga)

              # it has to arrive unchanged
              self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

          # server to client
          msga = os.urandom(103958)
          ptsbwriter.write(msga)

          # it has to arrive unchanged
          self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

          # end both plain text writers, everything else follows
          ptsbwriter.write_eof()
          ptcawriter.write_eof()

          # both readers end, w/o any left over data
          self.assertEqual(b'', await ptsbreader.read(1))
          self.assertTrue(ptsbreader.at_eof())
          self.assertEqual(b'', await ptcareader.read(1))
          self.assertTrue(ptcareader.at_eof())

          self.assertEqual([ 'dec', 'enc' ], await clientnf)
          self.assertEqual([ 'dec', 'enc' ], await servnf)

          await ptsbwriter.drain()
          await ptcawriter.drain()

          ptsbwriter.close()
          ptcawriter.close()
  849. self.assertEqual(b'', await ptcareader.read(1))
  850. self.assertTrue(ptcareader.at_eof())
  851. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  852. self.assertEqual([ 'dec', 'enc' ], await servnf)
  853. await ptsbwriter.drain()
  854. await ptcawriter.drain()
  855. ptsbwriter.close()
  856. ptcawriter.close()