From 75ceac2b8309ac4aeb19fe64e5e5b4a17e5040f2 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Sun, 13 Dec 2020 15:22:15 -0800 Subject: [PATCH] implement the server side of things... --- wsfwd/__init__.py | 198 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 157 insertions(+), 41 deletions(-) diff --git a/wsfwd/__init__.py b/wsfwd/__init__.py index f3047f4..438aa78 100644 --- a/wsfwd/__init__.py +++ b/wsfwd/__init__.py @@ -17,6 +17,19 @@ def timeout(timeout): return timeout_wrapper +def _tbprinter(fun): #pragma: no cover + @functools.wraps(fun) + async def wrapper(*args, **kwargs): + try: + return await fun(*args, **kwargs) + except Exception: + import traceback + print('in tbprinter:', repr(fun)) + traceback.print_exc() + raise + + return wrapper + class TestTimeout(unittest.IsolatedAsyncioTestCase): async def test_timeout(self): @timeout(.001) @@ -112,7 +125,74 @@ class WFStreamWriter: async def wait_closed(self): pass -class WSFWDClient: +class WSFWDCommon: + def __init__(self, reader, writer): + self._reader = reader + self._writer = writer + self._task = asyncio.create_task(self._process_msgs()) + + # this contains enqueued outbound data for draining + self._streams = dict() + + # dispatch incoming messages + self._procmsgs = dict() + + def add_stream_handler(self, stream, hndlr): + # XXX - make sure we don't overwrite an existing one + self._procmsgs[stream] = hndlr + + async def _process_msgs(self): + while True: + msg = await self._reader() + #print('got:', repr(msg)) + stream = msg[0] + await self._procmsgs[stream](msg[1:]) + + def sendstream(self, stream, *data): + if not all(isinstance(x, bytes) for x in data): + raise ValueError('write data must be bytes') + + self._streams.setdefault(stream, []).extend(data) + + async def drain(self, stream): + datalist = self._streams[stream] + data = datalist.copy() + + # clear the data + datalist[:] = [] + + senddata = b''.join(data) + + if senddata: + await self._writer(stream.to_bytes(1, 'big') + senddata) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + self._task.cancel() + + async def _sendcmd(self, cmd): + await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) + +class WSFWDServer(WSFWDCommon): + def __init__(self, reader, writer): + super().__init__(reader, writer) + + self.add_stream_handler(0, self._proccmds) + + async def _proccmds(self, msg): + msg = json.loads(msg) + handler = getattr(self, 'handle_%s' % msg['cmd']) + try: + resp = await handler(msg) + resp = dict(resp=msg['cmd']) + except RuntimeError as e: + resp = dict(resp=msg['cmd'], error=e.args[0]) + + await self._sendcmd(resp) + +class WSFWDClient(WSFWDCommon): def __init__(self, reader, writer): '''This is the client for doing command execution over a datagram protocol, such as WebSockets. The two @@ -125,29 +205,17 @@ class WSFWDClient: be sent to the server. ''' - self._reader = reader - self._writer = writer - self._task = asyncio.create_task(self._process_msgs()) - # this contains enqueued outbound data for draining - self._streams = dict() + super().__init__(reader, writer) self._cmdq = asyncio.Queue() - # dispatch incoming messages - self._procmsgs = dict() self._procmsgs[0] = self._process_cmdmsg # wait for msgs self._fetchpend = dict() self._fetchpend[0] = self._cmdq.get - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - self._task.cancel() - async def _process_msgs(self): while True: msg = await self._reader() @@ -166,29 +234,8 @@ class WSFWDClient: await self._cmdq.put(msg) - async def _sendcmd(self, cmd): - await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) - - def sendstream(self, stream, *data): - if not all(isinstance(x, bytes) for x in data): - raise ValueError('write data must be bytes') - - self._streams.setdefault(stream, []).extend(data) - - async def drain(self, stream): - datalist = self._streams[stream] - data = datalist.copy() - - # clear the data - datalist[:] = [] - - senddata = b''.join(data) - - if senddata: - await self._writer(stream.to_bytes(1, 'big') + senddata) - async def auth(self, auth): - await self._sendcmd(dict(auth=auth)) + await self._sendcmd(dict(cmd='auth', auth=auth)) rsp = await self._fetchpend[0]() @@ -247,10 +294,9 @@ class Test(unittest.IsolatedAsyncioTestCase): serv_task = self.runFakeServer(fake_server) - a = self.runClient() - - with self.assertRaises(RuntimeError): - await a.auth('randomtoken') + async with self.runClient() as a: + with self.assertRaises(RuntimeError): + await a.auth('randomtoken') await serv_task @@ -261,7 +307,7 @@ class Test(unittest.IsolatedAsyncioTestCase): token = 'sdlfkjsoidfjl' authdict = dict(bearer=token) - authmsg = { 'auth': authdict } + authmsg = { 'cmd': 'auth', 'auth': authdict } execargs = [ 'sldkfj', 'oweijfls' ] writerdata1 = b'asoijeflksdjf' writerdata2 = b'oiwuersldkj' @@ -417,3 +463,73 @@ class Test(unittest.IsolatedAsyncioTestCase): # That the process task does get terminated self.assertTrue(proctask.done()) + + @timeout(2) + async def test_client_server(self): + token = 'sdlfkjoijef' + authdict = dict(bearer=token) + badauthdict = dict(bearer=token + 'eofij') + + class TestWSFDServer(WSFWDServer): + async def handle_auth(self, msg): + if msg['auth'] == authdict: + return + + raise RuntimeError('Invalid auth') + + async def echo_handler(self, stream, msg): + self.sendstream(stream, msg) + await self.drain(stream) + + async def handle_exec(self, msg): + self.add_stream_handler(msg['stdin'], + functools.partial(self.echo_handler, + msg['stdout'])) + + async def closehandler(ccmsg): + if ccmsg['chan'] == msg['stdin']: + await self.drain(msg['stdout']) + await self._sendcmd(dict(cmd='chanclose', chan=msg['stdout'])) + await self._sendcmd(dict(cmd='exit', code=0)) + + # XXX - not happy w/ this api + self.handle_chanclose = closehandler + + server = TestWSFDServer(self.toserver.get, self.toclient.put) + + async with self.runClient() as a: + with self.assertRaises(RuntimeError) as cm: + await a.auth(badauthdict) + + re = cm.exception + self.assertEqual(re.args[0], "Got auth error: 'Invalid auth'") + + await a.auth(authdict) + + proc = await a.exec(args=['echo']) + + stdin, stdout = proc.stdin, proc.stdout + + data = b'seoiadsujflaksdfj' + stdin.write(data) + await stdin.drain() + + r = await stdout.read(len(data)) + + self.assertEqual(r, data) + + data = data[::-1] # reverse data + + stdin.write(data) + await stdin.drain() + + r = await stdout.read(len(data)) + + self.assertEqual(r, data) + + stdin.close() + await stdin.wait_closed() + + self.assertEqual(await stdout.read(), b'') + + self.assertEqual(await proc.wait(), 0)