An stunnel-like program that uses the Noise protocol.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

1345 lines
36 KiB

  1. from contextlib import asynccontextmanager
  2. from cryptography.hazmat.backends import default_backend
  3. from cryptography.hazmat.primitives import hashes
  4. from cryptography.hazmat.primitives import serialization
  5. from cryptography.hazmat.primitives.asymmetric import x448
  6. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  7. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  8. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  9. from noise.connection import NoiseConnection, Keypair
  10. #import tracemalloc; tracemalloc.start(100)
  11. import argparse
  12. import asyncio
  13. import base64
  14. import os.path
  15. import shutil
  16. import socket
  17. import sys
  18. import tempfile
  19. import time
  20. import threading
  21. import unittest
  22. _backend = default_backend()
  23. def loadprivkey(fname):
  24. with open(fname, encoding='ascii') as fp:
  25. data = fp.read().encode('ascii')
  26. key = load_pem_private_key(data, password=None, backend=default_backend())
  27. return key
  28. def loadprivkeyraw(fname):
  29. key = loadprivkey(fname)
  30. enc = serialization.Encoding.Raw
  31. privformat = serialization.PrivateFormat.Raw
  32. encalgo = serialization.NoEncryption()
  33. return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  34. def loadpubkeyraw(fname):
  35. with open(fname, encoding='ascii') as fp:
  36. lines = fp.readlines()
  37. # XXX
  38. #self.assertEqual(len(lines), 1)
  39. keytype, keyvalue = lines[0].split()
  40. if keytype != 'ntun-x448':
  41. raise RuntimeError
  42. return base64.urlsafe_b64decode(keyvalue)
  43. def genkeypair():
  44. '''Generates a keypair, and returns a tuple of (public, private).
  45. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  46. key = x448.X448PrivateKey.generate()
  47. enc = serialization.Encoding.Raw
  48. pubformat = serialization.PublicFormat.Raw
  49. privformat = serialization.PrivateFormat.Raw
  50. encalgo = serialization.NoEncryption()
  51. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  52. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  53. return pub, priv
  54. def _makefut(obj):
  55. loop = asyncio.get_running_loop()
  56. fut = loop.create_future()
  57. fut.set_result(obj)
  58. return fut
  59. def _makeunix(path):
  60. '''Make a properly formed unix path socket string.'''
  61. return 'unix:%s' % path
  62. # Make sure any additions are reflected by tests in test_parsesockstr
  63. _allowedparameters = {
  64. 'unix': {
  65. 'path': str,
  66. },
  67. 'tcp': {
  68. 'host': str,
  69. 'port': int,
  70. },
  71. }
  72. def parsesockstr(sockstr):
  73. '''Parse a socket string to its parts.
  74. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  75. If the proto has a default parameter, the value can be used
  76. directly, like: 'proto:value'. This is only allowed when the
  77. value can unambiguously be determined not to be a param. If
  78. there needs to be an equals '=', then you MUST use the extended
  79. version.
  80. The characters that define 'param' must be all lower case ascii
  81. characters and may contain an underscore. The first character
  82. must not be an underscore.
  83. Supported protocols:
  84. unix:
  85. Default parameter is path.
  86. The path parameter specifies the path to the
  87. unix domain socket. The path MUST start w/ a
  88. slash if it is used as a default parameter.
  89. tcp:
  90. Default parameter is host[:port].
  91. The host parameter specifies the host, and the
  92. port parameter specifies the port of the
  93. connection.
  94. '''
  95. proto, rem = sockstr.split(':', 1)
  96. if '=' not in rem:
  97. if proto == 'unix' and rem[0] != '/':
  98. raise ValueError('bare path MUST start w/ a slash (/).')
  99. if proto == 'unix':
  100. args = { 'path': rem }
  101. else:
  102. args = dict(i.split('=', 1) for i in rem.split(','))
  103. try:
  104. allowed = _allowedparameters[proto]
  105. except KeyError:
  106. raise ValueError('unsupported proto: %s' % repr(proto))
  107. extrakeys = args.keys() - allowed.keys()
  108. if extrakeys:
  109. raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))
  110. for i in args:
  111. args[i] = allowed[i](args[i])
  112. return proto, args
  113. async def connectsockstr(sockstr):
  114. '''Wrapper for asyncio.open_*_connection.'''
  115. proto, args = parsesockstr(sockstr)
  116. if proto == 'unix':
  117. fun = asyncio.open_unix_connection
  118. elif proto == 'tcp':
  119. fun = asyncio.open_connection
  120. reader, writer = await fun(**args)
  121. return reader, writer
  122. async def listensockstr(sockstr, cb):
  123. '''Wrapper for asyncio.start_x_server.
  124. For the format of sockstr, please see parsesockstr.
  125. The cb parameter is passed to asyncio's start_server or related
  126. calls. Per those docs, the cb parameter is calls or scheduled
  127. as a task when a client establishes a connection. It is called
  128. with two arguments, the reader and writer streams. For more
  129. information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
  130. '''
  131. proto, args = parsesockstr(sockstr)
  132. if proto == 'unix':
  133. fun = asyncio.start_unix_server
  134. elif proto == 'tcp':
  135. fun = asyncio.start_server
  136. return await fun(cb, **args)
  137. # !!python makemessagelengths.py
  138. _handshakelens = \
  139. [72, 72, 88]
  140. def _genciphfun(hash, ad):
  141. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  142. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  143. key = hkdf.derive(hash)
  144. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  145. backend=_backend)
  146. enctor = cipher.encryptor()
  147. def encfun(data):
  148. # Returns the two bytes for length
  149. val = len(data)
  150. encbytes = enctor.update(data[:16])
  151. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  152. return (val ^ mask).to_bytes(length=2, byteorder='big')
  153. def decfun(data):
  154. # takes off the data and returns the total
  155. # length
  156. val = int.from_bytes(data[:2], byteorder='big')
  157. encbytes = enctor.update(data[2:2 + 16])
  158. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  159. return val ^ mask
  160. return encfun, decfun
  161. async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
  162. '''A function that forwards data between the plain text pair of
  163. streams to the encrypted session.
  164. The mode paramater must be one of 'init' or 'resp' for initiator
  165. and responder.
  166. The encrdrwrr is an await object that will return a tunle of the
  167. reader and writer streams for the encrypted side of the
  168. connection.
  169. The ptpairfun parameter is a function that will be passed the
  170. public key bytes for the remote client. This can be used to
  171. both validate that the correct client is connecting, and to
  172. pass back the correct plain text reader/writer objects that
  173. match the provided static key. The function must be an async
  174. function.
  175. In the case of the initiator, pub_key must be provided and will
  176. be used to authenticate the responder side of the connection.
  177. The priv_key parameter is used to authenticate this side of the
  178. session.
  179. Both priv_key and pub_key parameters must be 56 bytes. For example,
  180. the pair that is returned by genkeypair.
  181. '''
  182. # Send a protocol version so that in the future we can change how
  183. # we interface, and possibly be able to send control messages,
  184. # allow the client to pass some misc data to the callback, or to
  185. # allow a reverse tunnel, were the client talks to the server,
  186. # and waits for the server to "connect" to the client w/ a
  187. # connection, e.g. reverse tunnel out behind a nat to allow
  188. # incoming connections.
  189. protocol_version = 0
  190. rdr, wrr = await encrdrwrr
  191. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  192. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  193. if pub_key is not None:
  194. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  195. pub_key)
  196. if mode == 'resp':
  197. proto.set_as_responder()
  198. proto.start_handshake()
  199. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  200. wrr.write(proto.write_message())
  201. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  202. elif mode == 'init':
  203. proto.set_as_initiator()
  204. proto.start_handshake()
  205. wrr.write(proto.write_message())
  206. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  207. wrr.write(proto.write_message())
  208. if not proto.handshake_finished: # pragma: no cover
  209. raise RuntimeError('failed to finish handshake')
  210. try:
  211. reader, writer = await ptpairfun(getattr(proto.get_keypair(
  212. Keypair.REMOTE_STATIC), 'public_bytes', None))
  213. except:
  214. wrr.close()
  215. raise
  216. # generate the keys for lengths
  217. # XXX - get_handshake_hash is probably not the best option, but
  218. # this is only to obscure lengths, it is not required to be secure
  219. # as the underlying NoiseProtocol securely validates everything.
  220. # It is marginally useful as writing patterns likely expose the
  221. # true length. Adding padding could marginally help w/ this.
  222. if mode == 'resp':
  223. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  224. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  225. elif mode == 'init':
  226. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  227. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  228. # protocol negotiation
  229. # send first, then wait for the response
  230. pvmsg = protocol_version.to_bytes(1, byteorder='big')
  231. encmsg = proto.encrypt(pvmsg)
  232. wrr.write(enclenfun(encmsg))
  233. wrr.write(encmsg)
  234. # get the protocol version
  235. msg = await rdr.readexactly(2 + 16)
  236. tlen = declenfun(msg)
  237. rmsg = await rdr.readexactly(tlen - 16)
  238. tmsg = msg[2:] + rmsg
  239. rpv = proto.decrypt(tmsg)
  240. rempv = int.from_bytes(rpv, byteorder='big')
  241. if rempv != protocol_version:
  242. raise RuntimeError('unsupported protovol version received: %d' %
  243. rempv)
  244. async def decses():
  245. try:
  246. while True:
  247. try:
  248. msg = await rdr.readexactly(2 + 16)
  249. except asyncio.streams.IncompleteReadError:
  250. if rdr.at_eof():
  251. return 'dec'
  252. tlen = declenfun(msg)
  253. rmsg = await rdr.readexactly(tlen - 16)
  254. tmsg = msg[2:] + rmsg
  255. writer.write(proto.decrypt(tmsg))
  256. await writer.drain()
  257. #except:
  258. # import traceback
  259. # traceback.print_exc()
  260. # raise
  261. finally:
  262. try:
  263. writer.write_eof()
  264. except OSError as e:
  265. if e.errno != 57:
  266. raise
  267. async def encses():
  268. try:
  269. while True:
  270. # largest message
  271. ptmsg = await reader.read(65535 - 16)
  272. if not ptmsg:
  273. # eof
  274. return 'enc'
  275. encmsg = proto.encrypt(ptmsg)
  276. wrr.write(enclenfun(encmsg))
  277. wrr.write(encmsg)
  278. await wrr.drain()
  279. #except:
  280. # import traceback
  281. # traceback.print_exc()
  282. # raise
  283. finally:
  284. wrr.write_eof()
  285. res = await asyncio.gather(decses(), encses())
  286. await wrr.drain() # not sure if needed
  287. wrr.close()
  288. await writer.drain() # not sure if needed
  289. writer.close()
  290. return res
  291. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  292. # Slightly modified to timeout and to print trace back when canceled.
  293. # This makes it easier to figure out what "froze".
  294. def async_test(f):
  295. def wrapper(*args, **kwargs):
  296. async def tbcapture():
  297. try:
  298. return await f(*args, **kwargs)
  299. except asyncio.CancelledError as e:
  300. # if we are going to be cancelled, print out a tb
  301. import traceback
  302. traceback.print_exc()
  303. raise
  304. loop = asyncio.get_event_loop()
  305. # timeout after 4 seconds
  306. loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
  307. return wrapper
  308. class Tests_misc(unittest.TestCase):
  309. def setUp(self):
  310. # setup temporary directory
  311. d = os.path.realpath(tempfile.mkdtemp())
  312. self.basetempdir = d
  313. self.tempdir = os.path.join(d, 'subdir')
  314. os.mkdir(self.tempdir)
  315. os.chdir(self.tempdir)
  316. def tearDown(self):
  317. #print('td:', time.time())
  318. shutil.rmtree(self.basetempdir)
  319. self.tempdir = None
  320. def test_parsesockstr_bad(self):
  321. badstrs = [
  322. 'unix:ff',
  323. 'randomnocolon',
  324. 'unix:somethingelse=bogus',
  325. 'tcp:port=bogus',
  326. ]
  327. for i in badstrs:
  328. with self.assertRaises(ValueError,
  329. msg='Should have failed processing: %s' % repr(i)):
  330. parsesockstr(i)
  331. def test_parsesockstr(self):
  332. results = {
  333. # Not all of these are valid when passed to a *sockstr
  334. # function
  335. 'unix:/apath': ('unix', { 'path': '/apath' }),
  336. 'unix:path=apath': ('unix', { 'path': 'apath' }),
  337. 'tcp:host=apath': ('tcp', { 'host': 'apath' }),
  338. 'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
  339. 'port': 5 }),
  340. }
  341. for s, r in results.items():
  342. self.assertEqual(parsesockstr(s), r)
  343. @async_test
  344. async def test_listensockstr_bad(self):
  345. with self.assertRaises(ValueError):
  346. ls = await listensockstr('bogus:some=arg', None)
  347. with self.assertRaises(ValueError):
  348. ls = await connectsockstr('bogus:some=arg')
  349. @async_test
  350. async def test_listenconnectsockstr(self):
  351. msgsent = b'this is a test message'
  352. msgrcv = b'testing message for receive'
  353. # That when a connection is received and receives and sends
  354. async def servconfhandle(rdr, wrr):
  355. msg = await rdr.readexactly(len(msgsent))
  356. self.assertEqual(msg, msgsent)
  357. #print(repr(wrr.get_extra_info('sockname')))
  358. wrr.write(msgrcv)
  359. await wrr.drain()
  360. wrr.close()
  361. return True
  362. # Test listensockstr
  363. for sstr, confun in [
  364. ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
  365. ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
  366. ]:
  367. # that listensockstr will bind to the correct path, can call cb
  368. ls = await listensockstr(sstr, servconfhandle)
  369. # that we open a connection to the path
  370. rdr, wrr = await confun()
  371. # and send a message
  372. wrr.write(msgsent)
  373. # and receive the message
  374. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  375. self.assertEqual(rcv, msgrcv)
  376. wrr.close()
  377. # Now test that connectsockstr works similarly.
  378. rdr, wrr = await connectsockstr(sstr)
  379. # and send a message
  380. wrr.write(msgsent)
  381. # and receive the message
  382. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  383. self.assertEqual(rcv, msgrcv)
  384. wrr.close()
  385. ls.close()
  386. await ls.wait_closed()
  387. def test_genciphfun(self):
  388. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  389. msg = b'this is a bunch of data'
  390. tb = enc(msg)
  391. self.assertEqual(len(msg), dec(tb + msg))
  392. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  393. msg = os.urandom(i)
  394. self.assertEqual(len(msg), dec(enc(msg) + msg))
  395. def cmd_client(args):
  396. privkey = loadprivkeyraw(args.clientkey)
  397. pubkey = loadpubkeyraw(args.servkey)
  398. async def runnf(rdr, wrr):
  399. encpair = asyncio.create_task(connectsockstr(args.clienttarget))
  400. a = await NoiseForwarder('init',
  401. encpair, lambda x: _makefut((rdr, wrr)),
  402. priv_key=privkey, pub_key=pubkey)
  403. # Setup client listener
  404. ssock = listensockstr(args.clientlisten, runnf)
  405. loop = asyncio.get_event_loop()
  406. obj = loop.run_until_complete(ssock)
  407. loop.run_until_complete(obj.serve_forever())
  408. def cmd_server(args):
  409. privkey = loadprivkeyraw(args.servkey)
  410. pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]
  411. async def runnf(rdr, wrr):
  412. async def checkclientfun(clientkey):
  413. if clientkey not in pubkeys:
  414. raise RuntimeError('invalid key provided')
  415. return await connectsockstr(args.servtarget)
  416. a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
  417. checkclientfun, priv_key=privkey)
  418. # Setup server listener
  419. ssock = listensockstr(args.servlisten, runnf)
  420. loop = asyncio.get_event_loop()
  421. obj = loop.run_until_complete(ssock)
  422. loop.run_until_complete(obj.serve_forever())
  423. def cmd_genkey(args):
  424. keypair = genkeypair()
  425. key = x448.X448PrivateKey.generate()
  426. # public key part
  427. enc = serialization.Encoding.Raw
  428. pubformat = serialization.PublicFormat.Raw
  429. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  430. try:
  431. fname = args.fname + '.pub'
  432. with open(fname, 'x', encoding='ascii') as fp:
  433. print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
  434. except FileExistsError:
  435. print('failed to create %s, file exists.' % fname, file=sys.stderr)
  436. sys.exit(2)
  437. enc = serialization.Encoding.PEM
  438. format = serialization.PrivateFormat.PKCS8
  439. encalgo = serialization.NoEncryption()
  440. with open(args.fname, 'x', encoding='ascii') as fp:
  441. fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii'))
  442. def main():
  443. parser = argparse.ArgumentParser()
  444. subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')
  445. parser_gk = subparsers.add_parser('genkey', help='generate keys')
  446. parser_gk.add_argument('fname', type=str, help='file name for the key')
  447. parser_gk.set_defaults(func=cmd_genkey)
  448. parser_serv = subparsers.add_parser('server', help='run a server')
  449. parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
  450. parser_serv.add_argument('servkey', type=str, help='file name for the server key')
  451. parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
  452. parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
  453. parser_serv.set_defaults(func=cmd_server)
  454. parser_client = subparsers.add_parser('client', help='run a client')
  455. parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
  456. parser_client.add_argument('servkey', type=str, help='file name for the server public key')
  457. parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
  458. parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
  459. parser_client.set_defaults(func=cmd_client)
  460. args = parser.parse_args()
  461. try:
  462. fun = args.func
  463. except AttributeError:
  464. parser.print_usage()
  465. sys.exit(5)
  466. fun(args)
  467. if __name__ == '__main__': # pragma: no cover
  468. main()
  469. def _asyncsockpair():
  470. '''Create a pair of sockets that are bound to each other.
  471. The function will return a tuple of two coroutine's, that
  472. each, when await'ed upon, will return the reader/writer pair.'''
  473. socka, sockb = socket.socketpair()
  474. return asyncio.open_connection(sock=socka), \
  475. asyncio.open_connection(sock=sockb)
  476. async def _awaitfile(fname):
  477. while not os.path.exists(fname):
  478. await asyncio.sleep(.01)
  479. return True
  480. class TestMain(unittest.TestCase):
  481. def setUp(self):
  482. # setup temporary directory
  483. d = os.path.realpath(tempfile.mkdtemp())
  484. self.basetempdir = d
  485. self.tempdir = os.path.join(d, 'subdir')
  486. os.mkdir(self.tempdir)
  487. # Generate key pairs
  488. self.server_key_pair = genkeypair()
  489. self.client_key_pair = genkeypair()
  490. os.chdir(self.tempdir)
  491. def tearDown(self):
  492. #print('td:', time.time())
  493. shutil.rmtree(self.basetempdir)
  494. self.tempdir = None
  495. @asynccontextmanager
  496. async def run_with_args(self, *args, pipes=True):
  497. kwargs = {}
  498. if pipes:
  499. kwargs.update(dict(
  500. stdout=asyncio.subprocess.PIPE,
  501. stderr=asyncio.subprocess.PIPE))
  502. aproc = asyncio.create_subprocess_exec(sys.executable,
  503. # XXX - figure out how to add coverage data on these runs
  504. #'-m', 'coverage', 'run', '-p',
  505. __file__, *args, **kwargs)
  506. try:
  507. proc = await aproc
  508. yield proc
  509. finally:
  510. if proc.returncode is None:
  511. proc.terminate()
  512. # Make sure that process exits before continuing
  513. await proc.wait()
  514. @async_test
  515. async def test_noargs(self):
  516. async with self.run_with_args() as proc:
  517. await proc.wait()
  518. # XXX - not checking error message
  519. # And that it exited w/ the correct code
  520. self.assertEqual(proc.returncode, 5)
  521. async def genkey(self, name):
  522. async with self.run_with_args('genkey', name, pipes=False) as proc:
  523. await proc.wait()
  524. self.assertEqual(proc.returncode, 0)
  525. @async_test
  526. async def test_loadpubkey(self):
  527. keypath = os.path.join(self.tempdir, 'loadpubkeytest')
  528. await self.genkey(keypath)
  529. privkey = loadprivkey(keypath)
  530. enc = serialization.Encoding.Raw
  531. pubformat = serialization.PublicFormat.Raw
  532. pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
  533. format=pubformat)
  534. pubkey = loadpubkeyraw(keypath + '.pub')
  535. self.assertEqual(pubkeybytes, pubkey)
  536. privrawkey = loadprivkeyraw(keypath)
  537. enc = serialization.Encoding.Raw
  538. privformat = serialization.PrivateFormat.Raw
  539. encalgo = serialization.NoEncryption()
  540. rprivrawkey = privkey.private_bytes(encoding=enc,
  541. format=privformat, encryption_algorithm=encalgo)
  542. self.assertEqual(rprivrawkey, privrawkey)
  543. @async_test
  544. async def test_clientkeymismatch(self):
  545. # make sure that if there's a client key mismatch, we
  546. # don't connect
  547. # Generate necessar keys
  548. servkeypath = os.path.join(self.tempdir, 'server_key')
  549. await self.genkey(servkeypath)
  550. clientkeypath = os.path.join(self.tempdir, 'client_key')
  551. await self.genkey(clientkeypath)
  552. badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
  553. await self.genkey(badclientkeypath)
  554. # forwards connectsion to this socket (created by client)
  555. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  556. ptclientstr = _makeunix(ptclientpath)
  557. # this is the socket server listen to
  558. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  559. incservstr = _makeunix(incservpath)
  560. # to this socket, opened by server
  561. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  562. servtargstr = _makeunix(servtargpath)
  563. # Setup server target listener
  564. ptsockevent = asyncio.Event()
  565. # Bind to pt listener
  566. lsock = await listensockstr(servtargstr, None)
  567. # Startup the server
  568. wserver = self.run_with_args('server',
  569. '-c', clientkeypath + '.pub',
  570. servkeypath, incservstr, servtargstr)
  571. # Startup the client with the "bad" key
  572. wclient = self.run_with_args('client', badclientkeypath,
  573. servkeypath + '.pub', ptclientstr, incservstr)
  574. async with wserver as server, wclient as client:
  575. # wait for server target to be created
  576. await _awaitfile(servtargpath)
  577. # wait for server to start
  578. await _awaitfile(incservpath)
  579. # wait for client to start
  580. await _awaitfile(ptclientpath)
  581. # Connect to the client
  582. reader, writer = await connectsockstr(ptclientstr)
  583. # XXX - this might not be the best test.
  584. with self.assertRaises(asyncio.futures.TimeoutError):
  585. # make sure that we don't get the conenction
  586. await asyncio.wait_for(ptsockevent.wait(), .5)
  587. writer.close()
  588. # Make sure that when the server is terminated
  589. server.terminate()
  590. # that it's stderr
  591. stdout, stderr = await server.communicate()
  592. #print('s:', repr((stdout, stderr)))
  593. # doesn't have an exceptions never retrieved
  594. # even the example echo server has this same leak
  595. #self.assertNotIn(b'Task exception was never retrieved', stderr)
  596. lsock.close()
  597. await lsock.wait_closed()
  598. # Kill off the client
  599. client.terminate()
  600. stdout, stderr = await client.communicate()
  601. #print('s:', repr((stdout, stderr)))
  602. # XXX - figure out how to clean up client properly
  603. @async_test
  604. async def test_end2end(self):
  605. # Generate necessar keys
  606. servkeypath = os.path.join(self.tempdir, 'server_key')
  607. await self.genkey(servkeypath)
  608. clientkeypath = os.path.join(self.tempdir, 'client_key')
  609. await self.genkey(clientkeypath)
  610. # forwards connectsion to this socket (created by client)
  611. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  612. ptclientstr = _makeunix(ptclientpath)
  613. # this is the socket server listen to
  614. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  615. incservstr = _makeunix(incservpath)
  616. # to this socket, opened by server
  617. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  618. servtargstr = _makeunix(servtargpath)
  619. # Setup server target listener
  620. ptsock = []
  621. ptsockevent = asyncio.Event()
  622. def ptsockaccept(reader, writer, ptsock=ptsock):
  623. ptsock.append((reader, writer))
  624. ptsockevent.set()
  625. # Bind to pt listener
  626. lsock = await listensockstr(servtargstr, ptsockaccept)
  627. # Startup the server
  628. wserver = self.run_with_args('server',
  629. '-c', clientkeypath + '.pub',
  630. servkeypath, incservstr, servtargstr,
  631. pipes=False)
  632. # Startup the client
  633. wclient = self.run_with_args('client',
  634. clientkeypath, servkeypath + '.pub', ptclientstr,
  635. incservstr, pipes=False)
  636. async with wserver as server, wclient as client:
  637. # wait for server target to be created
  638. await _awaitfile(servtargpath)
  639. # wait for server to start
  640. await _awaitfile(incservpath)
  641. # wait for client to start
  642. await _awaitfile(ptclientpath)
  643. # Connect to the client
  644. reader, writer = await connectsockstr(ptclientstr)
  645. # send a message
  646. ptmsg = b'this is a message for testing'
  647. writer.write(ptmsg)
  648. # make sure that we got the conenction
  649. await ptsockevent.wait()
  650. # get the connection
  651. endrdr, endwrr = ptsock[0]
  652. # make sure we can read back what we sent
  653. self.assertEqual(ptmsg,
  654. await endrdr.readexactly(len(ptmsg)))
  655. # test some additional messages
  656. for i in [ 129, 1287, 28792, 129872 ]:
  657. # in on direction
  658. msg = os.urandom(i)
  659. writer.write(msg)
  660. self.assertEqual(msg,
  661. await endrdr.readexactly(len(msg)))
  662. # and the other
  663. endwrr.write(msg)
  664. self.assertEqual(msg,
  665. await reader.readexactly(len(msg)))
  666. writer.close()
  667. endwrr.close()
  668. lsock.close()
  669. await lsock.wait_closed()
  670. # XXX - more testing that things exited properly
  671. @async_test
  672. async def test_genkey(self):
  673. # that it can generate a key
  674. async with self.run_with_args('genkey', 'somefile') as proc:
  675. await proc.wait()
  676. #print(await proc.communicate())
  677. self.assertEqual(proc.returncode, 0)
  678. with open('somefile.pub', encoding='ascii') as fp:
  679. lines = fp.readlines()
  680. self.assertEqual(len(lines), 1)
  681. keytype, keyvalue = lines[0].split()
  682. self.assertEqual(keytype, 'ntun-x448')
  683. key = x448.X448PublicKey.from_public_bytes(
  684. base64.urlsafe_b64decode(keyvalue))
  685. key = loadprivkey('somefile')
  686. self.assertIsInstance(key, x448.X448PrivateKey)
  687. # that a second call fails
  688. async with self.run_with_args('genkey', 'somefile') as proc:
  689. await proc.wait()
  690. stdoutdata, stderrdata = await proc.communicate()
  691. self.assertFalse(stdoutdata)
  692. self.assertEqual(
  693. b'failed to create somefile.pub, file exists.\n',
  694. stderrdata)
  695. # And that it exited w/ the correct code
  696. self.assertEqual(proc.returncode, 2)
  697. class TestNoiseFowarder(unittest.TestCase):
  698. def setUp(self):
  699. # setup temporary directory
  700. d = os.path.realpath(tempfile.mkdtemp())
  701. self.basetempdir = d
  702. self.tempdir = os.path.join(d, 'subdir')
  703. os.mkdir(self.tempdir)
  704. # Generate key pairs
  705. self.server_key_pair = genkeypair()
  706. self.client_key_pair = genkeypair()
  707. def tearDown(self):
  708. shutil.rmtree(self.basetempdir)
  709. self.tempdir = None
  710. @async_test
  711. async def test_clientkeymissmatch(self):
  712. # generate a key that is incorrect
  713. wrongclient_key_pair = genkeypair()
  714. # the secure socket
  715. clssockapair, clssockbpair = _asyncsockpair()
  716. reader, writer = await clssockapair
  717. async def wrongkey(v):
  718. raise ValueError('no key matches')
  719. # create the server
  720. servnf = asyncio.create_task(NoiseForwarder('resp',
  721. clssockbpair, wrongkey,
  722. priv_key=self.server_key_pair[1]))
  723. # Create client
  724. proto = NoiseConnection.from_name(
  725. b'Noise_XK_448_ChaChaPoly_SHA256')
  726. proto.set_as_initiator()
  727. # Setup wrong client key
  728. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  729. wrongclient_key_pair[1])
  730. # but the correct server key
  731. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  732. self.server_key_pair[0])
  733. proto.start_handshake()
  734. # Send first message
  735. message = proto.write_message()
  736. self.assertEqual(len(message), _handshakelens[0])
  737. writer.write(message)
  738. # Get response
  739. respmsg = await reader.readexactly(_handshakelens[1])
  740. proto.read_message(respmsg)
  741. # Send final reply
  742. message = proto.write_message()
  743. writer.write(message)
  744. # Make sure handshake has completed
  745. self.assertTrue(proto.handshake_finished)
  746. with self.assertRaises(ValueError):
  747. await servnf
  748. writer.close()
  749. @async_test
  750. async def test_server(self):
  751. # Test is plumbed:
  752. # (reader, writer) -> servsock ->
  753. # (rdr, wrr) NoiseForward (reader, writer) ->
  754. # servptsock -> (ptsock[0], ptsock[1])
  755. # Path that the server will sit on
  756. servsockpath = os.path.join(self.tempdir, 'servsock')
  757. servarg = _makeunix(servsockpath)
  758. # Path that the server will send pt data to
  759. servptpath = os.path.join(self.tempdir, 'servptsock')
  760. # Setup pt target listener
  761. pttarg = _makeunix(servptpath)
  762. ptsock = []
  763. ptsockevent = asyncio.Event()
  764. def ptsockaccept(reader, writer, ptsock=ptsock):
  765. ptsock.append((reader, writer))
  766. ptsockevent.set()
  767. # Bind to pt listener
  768. lsock = await listensockstr(pttarg, ptsockaccept)
  769. nfs = []
  770. event = asyncio.Event()
  771. async def runnf(rdr, wrr):
  772. ptpairfun = asyncio.create_task(connectsockstr(pttarg))
  773. a = await NoiseForwarder('resp',
  774. _makefut((rdr, wrr)), lambda x: ptpairfun,
  775. priv_key=self.server_key_pair[1])
  776. nfs.append(a)
  777. event.set()
  778. # Setup server listener
  779. ssock = await listensockstr(servarg, runnf)
  780. # Connect to server
  781. reader, writer = await connectsockstr(servarg)
  782. # Create client
  783. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  784. proto.set_as_initiator()
  785. # Setup required keys
  786. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  787. self.client_key_pair[1])
  788. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  789. self.server_key_pair[0])
  790. proto.start_handshake()
  791. # Send first message
  792. message = proto.write_message()
  793. self.assertEqual(len(message), _handshakelens[0])
  794. writer.write(message)
  795. # Get response
  796. respmsg = await reader.readexactly(_handshakelens[1])
  797. proto.read_message(respmsg)
  798. # Send final reply
  799. message = proto.write_message()
  800. writer.write(message)
  801. # Make sure handshake has completed
  802. self.assertTrue(proto.handshake_finished)
  803. # generate the keys for lengths
  804. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  805. b'toresp')
  806. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  807. b'toinit')
  808. pversion = 0
  809. # Send the protocol version string first
  810. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  811. writer.write(enclenfun(encmsg))
  812. writer.write(encmsg)
  813. # Read the peer's protocol version
  814. # find out how much we need to read
  815. encmsg = await reader.readexactly(2 + 16)
  816. tlen = declenfun(encmsg)
  817. # read the rest of the message
  818. rencmsg = await reader.readexactly(tlen - 16)
  819. tmsg = encmsg[2:] + rencmsg
  820. rptmsg = proto.decrypt(tmsg)
  821. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)
  822. # write a test message
  823. ptmsg = b'this is a test message that should be a little in length'
  824. encmsg = proto.encrypt(ptmsg)
  825. writer.write(enclenfun(encmsg))
  826. writer.write(encmsg)
  827. # wait for the connection to arrive
  828. await ptsockevent.wait()
  829. ptreader, ptwriter = ptsock[0]
  830. # read the test message
  831. rptmsg = await ptreader.readexactly(len(ptmsg))
  832. self.assertEqual(rptmsg, ptmsg)
  833. # write a different message
  834. ptmsg = os.urandom(2843)
  835. encmsg = proto.encrypt(ptmsg)
  836. writer.write(enclenfun(encmsg))
  837. writer.write(encmsg)
  838. # read the test message
  839. rptmsg = await ptreader.readexactly(len(ptmsg))
  840. self.assertEqual(rptmsg, ptmsg)
  841. # now try the other way
  842. ptmsg = os.urandom(912)
  843. ptwriter.write(ptmsg)
  844. # find out how much we need to read
  845. encmsg = await reader.readexactly(2 + 16)
  846. tlen = declenfun(encmsg)
  847. # read the rest of the message
  848. rencmsg = await reader.readexactly(tlen - 16)
  849. tmsg = encmsg[2:] + rencmsg
  850. rptmsg = proto.decrypt(tmsg)
  851. self.assertEqual(rptmsg, ptmsg)
  852. # shut down sending
  853. writer.write_eof()
  854. # so pt reader should be shut down
  855. self.assertEqual(b'', await ptreader.read(1))
  856. self.assertTrue(ptreader.at_eof())
  857. # shut down pt
  858. ptwriter.write_eof()
  859. # make sure the enc reader is eof
  860. self.assertEqual(b'', await reader.read(1))
  861. self.assertTrue(reader.at_eof())
  862. await event.wait()
  863. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  864. writer.close()
  865. ptwriter.close()
  866. lsock.close()
  867. ssock.close()
  868. await lsock.wait_closed()
  869. await ssock.wait_closed()
  870. @async_test
  871. async def test_protocolversionmismatch(self):
  872. # make sure that if we send a future version, that we
  873. # still get a protocol version, and that the connection
  874. # is closed w/o establishing a connection to the remote
  875. # side
  876. # Test is plumbed:
  877. # (reader, writer) -> servsock ->
  878. # (rdr, wrr) NoiseForward (reader, writer) ->
  879. # servptsock -> (ptsock[0], ptsock[1])
  880. # Path that the server will sit on
  881. servsockpath = os.path.join(self.tempdir, 'servsock')
  882. servarg = _makeunix(servsockpath)
  883. # Path that the server will send pt data to
  884. servptpath = os.path.join(self.tempdir, 'servptsock')
  885. # Setup pt target listener
  886. pttarg = _makeunix(servptpath)
  887. ptsock = []
  888. ptsockevent = asyncio.Event()
  889. def ptsockaccept(reader, writer, ptsock=ptsock):
  890. ptsock.append((reader, writer))
  891. ptsockevent.set()
  892. # Bind to pt listener
  893. lsock = await listensockstr(pttarg, ptsockaccept)
  894. nfs = []
  895. event = asyncio.Event()
  896. async def runnf(rdr, wrr):
  897. ptpairfun = asyncio.create_task(connectsockstr(pttarg))
  898. try:
  899. a = await NoiseForwarder('resp',
  900. _makefut((rdr, wrr)), lambda x: ptpairfun,
  901. priv_key=self.server_key_pair[1])
  902. except RuntimeError as e:
  903. nfs.append(e)
  904. event.set()
  905. return
  906. nfs.append(a)
  907. event.set()
  908. # Setup server listener
  909. ssock = await listensockstr(servarg, runnf)
  910. # Connect to server
  911. reader, writer = await connectsockstr(servarg)
  912. # Create client
  913. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  914. proto.set_as_initiator()
  915. # Setup required keys
  916. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  917. self.client_key_pair[1])
  918. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  919. self.server_key_pair[0])
  920. proto.start_handshake()
  921. # Send first message
  922. message = proto.write_message()
  923. self.assertEqual(len(message), _handshakelens[0])
  924. writer.write(message)
  925. # Get response
  926. respmsg = await reader.readexactly(_handshakelens[1])
  927. proto.read_message(respmsg)
  928. # Send final reply
  929. message = proto.write_message()
  930. writer.write(message)
  931. # Make sure handshake has completed
  932. self.assertTrue(proto.handshake_finished)
  933. # generate the keys for lengths
  934. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  935. b'toresp')
  936. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  937. b'toinit')
  938. pversion = 1
  939. # Send the protocol version string first
  940. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  941. writer.write(enclenfun(encmsg))
  942. writer.write(encmsg)
  943. # Read the peer's protocol version
  944. # find out how much we need to read
  945. encmsg = await reader.readexactly(2 + 16)
  946. tlen = declenfun(encmsg)
  947. # read the rest of the message
  948. rencmsg = await reader.readexactly(tlen - 16)
  949. tmsg = encmsg[2:] + rencmsg
  950. rptmsg = proto.decrypt(tmsg)
  951. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)
  952. await event.wait()
  953. self.assertIsInstance(nfs[0], RuntimeError)
  954. @async_test
  955. async def test_serverclient(self):
  956. # plumbing:
  957. #
  958. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  959. #
  960. ptcsockapair, ptcsockbpair = _asyncsockpair()
  961. ptcareader, ptcawriter = await ptcsockapair
  962. #ptcsockbpair passed directly
  963. clssockapair, clssockbpair = _asyncsockpair()
  964. #both passed directly
  965. ptssockapair, ptssockbpair = _asyncsockpair()
  966. #ptssockapair passed directly
  967. ptsbreader, ptsbwriter = await ptssockbpair
  968. async def validateclientkey(pubkey):
  969. self.assertEqual(pubkey, self.client_key_pair[0])
  970. return await ptssockapair
  971. clientnf = asyncio.create_task(NoiseForwarder('init',
  972. clssockapair, lambda x: ptcsockbpair,
  973. priv_key=self.client_key_pair[1],
  974. pub_key=self.server_key_pair[0]))
  975. servnf = asyncio.create_task(NoiseForwarder('resp',
  976. clssockbpair, validateclientkey,
  977. priv_key=self.server_key_pair[1]))
  978. # send a message
  979. msga = os.urandom(183)
  980. ptcawriter.write(msga)
  981. # make sure we get the same message
  982. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  983. # send a second message
  984. msga = os.urandom(2834)
  985. ptcawriter.write(msga)
  986. # make sure we get the same message
  987. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  988. # send a message larger than the block size
  989. msga = os.urandom(103958)
  990. ptcawriter.write(msga)
  991. # make sure we get the same message
  992. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  993. # send a message the other direction
  994. msga = os.urandom(103958)
  995. ptsbwriter.write(msga)
  996. # make sure we get the same message
  997. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  998. # close down the pt writers, the rest should follow
  999. ptsbwriter.write_eof()
  1000. ptcawriter.write_eof()
  1001. # make sure they are closed, and there is no more data
  1002. self.assertEqual(b'', await ptsbreader.read(1))
  1003. self.assertTrue(ptsbreader.at_eof())
  1004. self.assertEqual(b'', await ptcareader.read(1))
  1005. self.assertTrue(ptcareader.at_eof())
  1006. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  1007. self.assertEqual([ 'dec', 'enc' ], await servnf)
  1008. await ptsbwriter.drain()
  1009. await ptcawriter.drain()
  1010. ptsbwriter.close()
  1011. ptcawriter.close()