From 74ff15da8c4056c745326a2d65ef7027ea431599 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Thu, 31 Oct 2019 14:43:02 -0700 Subject: [PATCH] use an asynccontextmanager to make sure that subprocesses are terminated --- ntunnel.py | 215 ++++++++++++++++++++++++++++------------------------- 1 file changed, 114 insertions(+), 101 deletions(-) diff --git a/ntunnel.py b/ntunnel.py index 28a5e67..e8623c6 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization @@ -651,34 +652,40 @@ class TestMain(unittest.TestCase): shutil.rmtree(self.basetempdir) self.tempdir = None - @async_test - async def test_noargs(self): - proc = await self.run_with_args() - - await proc.wait() - - # XXX - not checking error message - - # And that it exited w/ the correct code - self.assertEqual(proc.returncode, 5) - - def run_with_args(self, *args, pipes=True): + @asynccontextmanager + async def run_with_args(self, *args, pipes=True): kwargs = {} if pipes: kwargs.update(dict( stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)) - return asyncio.create_subprocess_exec(sys.executable, + aproc = asyncio.create_subprocess_exec(sys.executable, # XXX - figure out how to add coverage data on these runs #'-m', 'coverage', 'run', '-p', __file__, *args, **kwargs) - async def genkey(self, name): - proc = await self.run_with_args('genkey', name, pipes=False) + try: + proc = await aproc + yield proc + finally: + if proc.returncode is None: + proc.terminate() - await proc.wait() + @async_test + async def test_noargs(self): + async with self.run_with_args() as proc: + await proc.wait() - self.assertEqual(proc.returncode, 0) + # XXX - not checking error message + + # And that it exited w/ the correct code + self.assertEqual(proc.returncode, 5) + + async def genkey(self, name): + async with self.run_with_args('genkey', name, pipes=False) as proc: + await proc.wait() + + self.assertEqual(proc.returncode, 0) @async_test async def test_loadpubkey(self): @@ -690,7 +697,8 @@ class TestMain(unittest.TestCase): enc = serialization.Encoding.Raw pubformat = serialization.PublicFormat.Raw - pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat) + pubkeybytes = privkey.public_key().public_bytes(encoding=enc, + format=pubformat) pubkey = loadpubkeyraw(keypath + '.pub') @@ -702,7 +710,8 @@ class TestMain(unittest.TestCase): privformat = serialization.PrivateFormat.Raw encalgo = serialization.NoEncryption() - rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) + rprivrawkey = privkey.private_bytes(encoding=enc, + format=privformat, encryption_algorithm=encalgo) self.assertEqual(rprivrawkey, privrawkey) @@ -738,53 +747,54 @@ class TestMain(unittest.TestCase): lsock = await listensockstr(servtargstr, None) # Startup the server - server = await self.run_with_args('server', + wserver = self.run_with_args('server', '-c', clientkeypath + '.pub', servkeypath, incservstr, servtargstr) # Startup the client with the "bad" key - client = await self.run_with_args('client', - badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr) + wclient = self.run_with_args('client', badclientkeypath, + servkeypath + '.pub', ptclientstr, incservstr) - # wait for server target to be created - await _awaitfile(servtargpath) + async with wserver as server, wclient as client: + # wait for server target to be created + await _awaitfile(servtargpath) - # wait for server to start - await _awaitfile(incservpath) + # wait for server to start + await _awaitfile(incservpath) - # wait for client to start - await _awaitfile(ptclientpath) + # wait for client to start + await _awaitfile(ptclientpath) - # Connect to the client - reader, writer = await connectsockstr(ptclientstr) + # Connect to the client + reader, writer = await connectsockstr(ptclientstr) - # XXX - this might not be the best test. - with self.assertRaises(asyncio.futures.TimeoutError): - # make sure that we don't get the conenction - await asyncio.wait_for(ptsockevent.wait(), .5) + # XXX - this might not be the best test. + with self.assertRaises(asyncio.futures.TimeoutError): + # make sure that we don't get the conenction + await asyncio.wait_for(ptsockevent.wait(), .5) - writer.close() + writer.close() - # Make sure that when the server is terminated - server.terminate() + # Make sure that when the server is terminated + server.terminate() - # that it's stderr - stdout, stderr = await server.communicate() - #print('s:', repr((stdout, stderr))) + # that it's stderr + stdout, stderr = await server.communicate() + #print('s:', repr((stdout, stderr))) - # doesn't have an exceptions never retrieved - # even the example echo server has this same leak - #self.assertNotIn(b'Task exception was never retrieved', stderr) + # doesn't have an exceptions never retrieved + # even the example echo server has this same leak + #self.assertNotIn(b'Task exception was never retrieved', stderr) - lsock.close() - await lsock.wait_closed() + lsock.close() + await lsock.wait_closed() - # Kill off the client - client.terminate() + # Kill off the client + client.terminate() - stdout, stderr = await client.communicate() - #print('s:', repr((stdout, stderr))) - # XXX - figure out how to clean up client properly + stdout, stderr = await client.communicate() + #print('s:', repr((stdout, stderr))) + # XXX - figure out how to clean up client properly @async_test async def test_end2end(self): @@ -817,72 +827,73 @@ class TestMain(unittest.TestCase): lsock = await listensockstr(servtargstr, ptsockaccept) # Startup the server - server = await self.run_with_args('server', + wserver = self.run_with_args('server', '-c', clientkeypath + '.pub', servkeypath, incservstr, servtargstr, pipes=False) # Startup the client - client = await self.run_with_args('client', - clientkeypath, servkeypath + '.pub', ptclientstr, incservstr, - pipes=False) + wclient = self.run_with_args('client', + clientkeypath, servkeypath + '.pub', ptclientstr, + incservstr, pipes=False) - # wait for server target to be created - await _awaitfile(servtargpath) + async with wserver as server, wclient as client: + # wait for server target to be created + await _awaitfile(servtargpath) - # wait for server to start - await _awaitfile(incservpath) + # wait for server to start + await _awaitfile(incservpath) - # wait for client to start - await _awaitfile(ptclientpath) + # wait for client to start + await _awaitfile(ptclientpath) - # Connect to the client - reader, writer = await connectsockstr(ptclientstr) + # Connect to the client + reader, writer = await connectsockstr(ptclientstr) - # send a message - ptmsg = b'this is a message for testing' - writer.write(ptmsg) + # send a message + ptmsg = b'this is a message for testing' + writer.write(ptmsg) - # make sure that we got the conenction - await ptsockevent.wait() + # make sure that we got the conenction + await ptsockevent.wait() - # get the connection - endrdr, endwrr = ptsock[0] + # get the connection + endrdr, endwrr = ptsock[0] - # make sure we can read back what we sent - self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg))) + # make sure we can read back what we sent + self.assertEqual(ptmsg, + await endrdr.readexactly(len(ptmsg))) - # test some additional messages - for i in [ 129, 1287, 28792, 129872 ]: - # in on direction - msg = os.urandom(i) - writer.write(msg) - self.assertEqual(msg, await endrdr.readexactly(len(msg))) + # test some additional messages + for i in [ 129, 1287, 28792, 129872 ]: + # in on direction + msg = os.urandom(i) + writer.write(msg) + self.assertEqual(msg, + await endrdr.readexactly(len(msg))) - # and the other - endwrr.write(msg) - self.assertEqual(msg, await reader.readexactly(len(msg))) + # and the other + endwrr.write(msg) + self.assertEqual(msg, + await reader.readexactly(len(msg))) - writer.close() - endwrr.close() + writer.close() + endwrr.close() - lsock.close() - await lsock.wait_closed() + lsock.close() + await lsock.wait_closed() - server.terminate() - client.terminate() - # XXX - more clean up testing + # XXX - more testing that things exited properly @async_test async def test_genkey(self): # that it can generate a key - proc = await self.run_with_args('genkey', 'somefile') - - await proc.wait() + async with self.run_with_args('genkey', 'somefile') as proc: + await proc.wait() - #print(await proc.communicate()) + #print(await proc.communicate()) - self.assertEqual(proc.returncode, 0) + self.assertEqual(proc.returncode, 0) with open('somefile.pub', encoding='ascii') as fp: lines = fp.readlines() @@ -891,23 +902,25 @@ class TestMain(unittest.TestCase): keytype, keyvalue = lines[0].split() self.assertEqual(keytype, 'ntun-x448') - key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) + key = x448.X448PublicKey.from_public_bytes( + base64.urlsafe_b64decode(keyvalue)) key = loadprivkey('somefile') self.assertIsInstance(key, x448.X448PrivateKey) # that a second call fails - proc = await self.run_with_args('genkey', 'somefile') - - await proc.wait() + async with self.run_with_args('genkey', 'somefile') as proc: + await proc.wait() - stdoutdata, stderrdata = await proc.communicate() + stdoutdata, stderrdata = await proc.communicate() - self.assertFalse(stdoutdata) - self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata) + self.assertFalse(stdoutdata) + self.assertEqual( + b'failed to create somefile.pub, file exists.\n', + stderrdata) - # And that it exited w/ the correct code - self.assertEqual(proc.returncode, 2) + # And that it exited w/ the correct code + self.assertEqual(proc.returncode, 2) class TestNoiseFowarder(unittest.TestCase): def setUp(self):