| @@ -17,6 +17,19 @@ def timeout(timeout): | |||
| return timeout_wrapper | |||
| def _tbprinter(fun): #pragma: no cover | |||
| @functools.wraps(fun) | |||
| async def wrapper(*args, **kwargs): | |||
| try: | |||
| return await fun(*args, **kwargs) | |||
| except Exception: | |||
| import traceback | |||
| print('in tbprinter:', repr(fun)) | |||
| traceback.print_exc() | |||
| raise | |||
| return wrapper | |||
| class TestTimeout(unittest.IsolatedAsyncioTestCase): | |||
| async def test_timeout(self): | |||
| @timeout(.001) | |||
| @@ -112,7 +125,74 @@ class WFStreamWriter: | |||
| async def wait_closed(self): | |||
| pass | |||
| class WSFWDClient: | |||
| class WSFWDCommon: | |||
| def __init__(self, reader, writer): | |||
| self._reader = reader | |||
| self._writer = writer | |||
| self._task = asyncio.create_task(self._process_msgs()) | |||
| # this contains enqueued outbound data for draining | |||
| self._streams = dict() | |||
| # dispatch incoming messages | |||
| self._procmsgs = dict() | |||
| def add_stream_handler(self, stream, hndlr): | |||
| # XXX - make sure we don't overwrite an existing one | |||
| self._procmsgs[stream] = hndlr | |||
| async def _process_msgs(self): | |||
| while True: | |||
| msg = await self._reader() | |||
| #print('got:', repr(msg)) | |||
| stream = msg[0] | |||
| await self._procmsgs[stream](msg[1:]) | |||
| def sendstream(self, stream, *data): | |||
| if not all(isinstance(x, bytes) for x in data): | |||
| raise ValueError('write data must be bytes') | |||
| self._streams.setdefault(stream, []).extend(data) | |||
| async def drain(self, stream): | |||
| datalist = self._streams[stream] | |||
| data = datalist.copy() | |||
| # clear the data | |||
| datalist[:] = [] | |||
| senddata = b''.join(data) | |||
| if senddata: | |||
| await self._writer(stream.to_bytes(1, 'big') + senddata) | |||
| async def __aenter__(self): | |||
| return self | |||
| async def __aexit__(self, exc_type, exc, tb): | |||
| self._task.cancel() | |||
| async def _sendcmd(self, cmd): | |||
| await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) | |||
| class WSFWDServer(WSFWDCommon): | |||
| def __init__(self, reader, writer): | |||
| super().__init__(reader, writer) | |||
| self.add_stream_handler(0, self._proccmds) | |||
| async def _proccmds(self, msg): | |||
| msg = json.loads(msg) | |||
| handler = getattr(self, 'handle_%s' % msg['cmd']) | |||
| try: | |||
| resp = await handler(msg) | |||
| resp = dict(resp=msg['cmd']) | |||
| except RuntimeError as e: | |||
| resp = dict(resp=msg['cmd'], error=e.args[0]) | |||
| await self._sendcmd(resp) | |||
| class WSFWDClient(WSFWDCommon): | |||
| def __init__(self, reader, writer): | |||
| '''This is the client for doing command execution over | |||
| a datagram protocol, such as WebSockets. The two | |||
| @@ -125,29 +205,17 @@ class WSFWDClient: | |||
| be sent to the server. | |||
| ''' | |||
| self._reader = reader | |||
| self._writer = writer | |||
| self._task = asyncio.create_task(self._process_msgs()) | |||
| # this contains enqueued outbound data for draining | |||
| self._streams = dict() | |||
| super().__init__(reader, writer) | |||
| self._cmdq = asyncio.Queue() | |||
| # dispatch incoming messages | |||
| self._procmsgs = dict() | |||
| self._procmsgs[0] = self._process_cmdmsg | |||
| # wait for msgs | |||
| self._fetchpend = dict() | |||
| self._fetchpend[0] = self._cmdq.get | |||
| async def __aenter__(self): | |||
| return self | |||
| async def __aexit__(self, exc_type, exc, tb): | |||
| self._task.cancel() | |||
| async def _process_msgs(self): | |||
| while True: | |||
| msg = await self._reader() | |||
| @@ -166,29 +234,8 @@ class WSFWDClient: | |||
| await self._cmdq.put(msg) | |||
| async def _sendcmd(self, cmd): | |||
| await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) | |||
| def sendstream(self, stream, *data): | |||
| if not all(isinstance(x, bytes) for x in data): | |||
| raise ValueError('write data must be bytes') | |||
| self._streams.setdefault(stream, []).extend(data) | |||
| async def drain(self, stream): | |||
| datalist = self._streams[stream] | |||
| data = datalist.copy() | |||
| # clear the data | |||
| datalist[:] = [] | |||
| senddata = b''.join(data) | |||
| if senddata: | |||
| await self._writer(stream.to_bytes(1, 'big') + senddata) | |||
| async def auth(self, auth): | |||
| await self._sendcmd(dict(auth=auth)) | |||
| await self._sendcmd(dict(cmd='auth', auth=auth)) | |||
| rsp = await self._fetchpend[0]() | |||
| @@ -247,10 +294,9 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||
| serv_task = self.runFakeServer(fake_server) | |||
| a = self.runClient() | |||
| with self.assertRaises(RuntimeError): | |||
| await a.auth('randomtoken') | |||
| async with self.runClient() as a: | |||
| with self.assertRaises(RuntimeError): | |||
| await a.auth('randomtoken') | |||
| await serv_task | |||
| @@ -261,7 +307,7 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||
| token = 'sdlfkjsoidfjl' | |||
| authdict = dict(bearer=token) | |||
| authmsg = { 'auth': authdict } | |||
| authmsg = { 'cmd': 'auth', 'auth': authdict } | |||
| execargs = [ 'sldkfj', 'oweijfls' ] | |||
| writerdata1 = b'asoijeflksdjf' | |||
| writerdata2 = b'oiwuersldkj' | |||
| @@ -417,3 +463,73 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||
| # That the process task does get terminated | |||
| self.assertTrue(proctask.done()) | |||
| @timeout(2) | |||
| async def test_client_server(self): | |||
| token = 'sdlfkjoijef' | |||
| authdict = dict(bearer=token) | |||
| badauthdict = dict(bearer=token + 'eofij') | |||
| class TestWSFDServer(WSFWDServer): | |||
| async def handle_auth(self, msg): | |||
| if msg['auth'] == authdict: | |||
| return | |||
| raise RuntimeError('Invalid auth') | |||
| async def echo_handler(self, stream, msg): | |||
| self.sendstream(stream, msg) | |||
| await self.drain(stream) | |||
| async def handle_exec(self, msg): | |||
| self.add_stream_handler(msg['stdin'], | |||
| functools.partial(self.echo_handler, | |||
| msg['stdout'])) | |||
| async def closehandler(ccmsg): | |||
| if ccmsg['chan'] == msg['stdin']: | |||
| await self.drain(msg['stdout']) | |||
| await self._sendcmd(dict(cmd='chanclose', chan=msg['stdout'])) | |||
| await self._sendcmd(dict(cmd='exit', code=0)) | |||
| # XXX - not happy w/ this api | |||
| self.handle_chanclose = closehandler | |||
| server = TestWSFDServer(self.toserver.get, self.toclient.put) | |||
| async with self.runClient() as a: | |||
| with self.assertRaises(RuntimeError) as cm: | |||
| await a.auth(badauthdict) | |||
| re = cm.exception | |||
| self.assertEqual(re.args[0], "Got auth error: 'Invalid auth'") | |||
| await a.auth(authdict) | |||
| proc = await a.exec(args=['echo']) | |||
| stdin, stdout = proc.stdin, proc.stdout | |||
| data = b'seoiadsujflaksdfj' | |||
| stdin.write(data) | |||
| await stdin.drain() | |||
| r = await stdout.read(len(data)) | |||
| self.assertEqual(r, data) | |||
| data = data[::-1] # reverse data | |||
| stdin.write(data) | |||
| await stdin.drain() | |||
| r = await stdout.read(len(data)) | |||
| self.assertEqual(r, data) | |||
| stdin.close() | |||
| await stdin.wait_closed() | |||
| self.assertEqual(await stdout.read(), b'') | |||
| self.assertEqual(await proc.wait(), 0) | |||