| @@ -17,6 +17,19 @@ def timeout(timeout): | |||||
| return timeout_wrapper | return timeout_wrapper | ||||
| def _tbprinter(fun): #pragma: no cover | |||||
| @functools.wraps(fun) | |||||
| async def wrapper(*args, **kwargs): | |||||
| try: | |||||
| return await fun(*args, **kwargs) | |||||
| except Exception: | |||||
| import traceback | |||||
| print('in tbprinter:', repr(fun)) | |||||
| traceback.print_exc() | |||||
| raise | |||||
| return wrapper | |||||
| class TestTimeout(unittest.IsolatedAsyncioTestCase): | class TestTimeout(unittest.IsolatedAsyncioTestCase): | ||||
| async def test_timeout(self): | async def test_timeout(self): | ||||
| @timeout(.001) | @timeout(.001) | ||||
| @@ -112,7 +125,74 @@ class WFStreamWriter: | |||||
| async def wait_closed(self): | async def wait_closed(self): | ||||
| pass | pass | ||||
| class WSFWDClient: | |||||
| class WSFWDCommon: | |||||
| def __init__(self, reader, writer): | |||||
| self._reader = reader | |||||
| self._writer = writer | |||||
| self._task = asyncio.create_task(self._process_msgs()) | |||||
| # this contains enqueued outbound data for draining | |||||
| self._streams = dict() | |||||
| # dispatch incoming messages | |||||
| self._procmsgs = dict() | |||||
| def add_stream_handler(self, stream, hndlr): | |||||
| # XXX - make sure we don't overwrite an existing one | |||||
| self._procmsgs[stream] = hndlr | |||||
| async def _process_msgs(self): | |||||
| while True: | |||||
| msg = await self._reader() | |||||
| #print('got:', repr(msg)) | |||||
| stream = msg[0] | |||||
| await self._procmsgs[stream](msg[1:]) | |||||
| def sendstream(self, stream, *data): | |||||
| if not all(isinstance(x, bytes) for x in data): | |||||
| raise ValueError('write data must be bytes') | |||||
| self._streams.setdefault(stream, []).extend(data) | |||||
| async def drain(self, stream): | |||||
| datalist = self._streams[stream] | |||||
| data = datalist.copy() | |||||
| # clear the data | |||||
| datalist[:] = [] | |||||
| senddata = b''.join(data) | |||||
| if senddata: | |||||
| await self._writer(stream.to_bytes(1, 'big') + senddata) | |||||
| async def __aenter__(self): | |||||
| return self | |||||
| async def __aexit__(self, exc_type, exc, tb): | |||||
| self._task.cancel() | |||||
| async def _sendcmd(self, cmd): | |||||
| await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) | |||||
| class WSFWDServer(WSFWDCommon): | |||||
| def __init__(self, reader, writer): | |||||
| super().__init__(reader, writer) | |||||
| self.add_stream_handler(0, self._proccmds) | |||||
| async def _proccmds(self, msg): | |||||
| msg = json.loads(msg) | |||||
| handler = getattr(self, 'handle_%s' % msg['cmd']) | |||||
| try: | |||||
| resp = await handler(msg) | |||||
| resp = dict(resp=msg['cmd']) | |||||
| except RuntimeError as e: | |||||
| resp = dict(resp=msg['cmd'], error=e.args[0]) | |||||
| await self._sendcmd(resp) | |||||
| class WSFWDClient(WSFWDCommon): | |||||
| def __init__(self, reader, writer): | def __init__(self, reader, writer): | ||||
| '''This is the client for doing command execution over | '''This is the client for doing command execution over | ||||
| a datagram protocol, such as WebSockets. The two | a datagram protocol, such as WebSockets. The two | ||||
| @@ -125,29 +205,17 @@ class WSFWDClient: | |||||
| be sent to the server. | be sent to the server. | ||||
| ''' | ''' | ||||
| self._reader = reader | |||||
| self._writer = writer | |||||
| self._task = asyncio.create_task(self._process_msgs()) | |||||
| # this contains enqueued outbound data for draining | |||||
| self._streams = dict() | |||||
| super().__init__(reader, writer) | |||||
| self._cmdq = asyncio.Queue() | self._cmdq = asyncio.Queue() | ||||
| # dispatch incoming messages | |||||
| self._procmsgs = dict() | |||||
| self._procmsgs[0] = self._process_cmdmsg | self._procmsgs[0] = self._process_cmdmsg | ||||
| # wait for msgs | # wait for msgs | ||||
| self._fetchpend = dict() | self._fetchpend = dict() | ||||
| self._fetchpend[0] = self._cmdq.get | self._fetchpend[0] = self._cmdq.get | ||||
| async def __aenter__(self): | |||||
| return self | |||||
| async def __aexit__(self, exc_type, exc, tb): | |||||
| self._task.cancel() | |||||
| async def _process_msgs(self): | async def _process_msgs(self): | ||||
| while True: | while True: | ||||
| msg = await self._reader() | msg = await self._reader() | ||||
| @@ -166,29 +234,8 @@ class WSFWDClient: | |||||
| await self._cmdq.put(msg) | await self._cmdq.put(msg) | ||||
| async def _sendcmd(self, cmd): | |||||
| await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) | |||||
| def sendstream(self, stream, *data): | |||||
| if not all(isinstance(x, bytes) for x in data): | |||||
| raise ValueError('write data must be bytes') | |||||
| self._streams.setdefault(stream, []).extend(data) | |||||
| async def drain(self, stream): | |||||
| datalist = self._streams[stream] | |||||
| data = datalist.copy() | |||||
| # clear the data | |||||
| datalist[:] = [] | |||||
| senddata = b''.join(data) | |||||
| if senddata: | |||||
| await self._writer(stream.to_bytes(1, 'big') + senddata) | |||||
| async def auth(self, auth): | async def auth(self, auth): | ||||
| await self._sendcmd(dict(auth=auth)) | |||||
| await self._sendcmd(dict(cmd='auth', auth=auth)) | |||||
| rsp = await self._fetchpend[0]() | rsp = await self._fetchpend[0]() | ||||
| @@ -247,10 +294,9 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||||
| serv_task = self.runFakeServer(fake_server) | serv_task = self.runFakeServer(fake_server) | ||||
| a = self.runClient() | |||||
| with self.assertRaises(RuntimeError): | |||||
| await a.auth('randomtoken') | |||||
| async with self.runClient() as a: | |||||
| with self.assertRaises(RuntimeError): | |||||
| await a.auth('randomtoken') | |||||
| await serv_task | await serv_task | ||||
| @@ -261,7 +307,7 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||||
| token = 'sdlfkjsoidfjl' | token = 'sdlfkjsoidfjl' | ||||
| authdict = dict(bearer=token) | authdict = dict(bearer=token) | ||||
| authmsg = { 'auth': authdict } | |||||
| authmsg = { 'cmd': 'auth', 'auth': authdict } | |||||
| execargs = [ 'sldkfj', 'oweijfls' ] | execargs = [ 'sldkfj', 'oweijfls' ] | ||||
| writerdata1 = b'asoijeflksdjf' | writerdata1 = b'asoijeflksdjf' | ||||
| writerdata2 = b'oiwuersldkj' | writerdata2 = b'oiwuersldkj' | ||||
| @@ -417,3 +463,73 @@ class Test(unittest.IsolatedAsyncioTestCase): | |||||
| # That the process task does get terminated | # That the process task does get terminated | ||||
| self.assertTrue(proctask.done()) | self.assertTrue(proctask.done()) | ||||
| @timeout(2) | |||||
| async def test_client_server(self): | |||||
| token = 'sdlfkjoijef' | |||||
| authdict = dict(bearer=token) | |||||
| badauthdict = dict(bearer=token + 'eofij') | |||||
| class TestWSFDServer(WSFWDServer): | |||||
| async def handle_auth(self, msg): | |||||
| if msg['auth'] == authdict: | |||||
| return | |||||
| raise RuntimeError('Invalid auth') | |||||
| async def echo_handler(self, stream, msg): | |||||
| self.sendstream(stream, msg) | |||||
| await self.drain(stream) | |||||
| async def handle_exec(self, msg): | |||||
| self.add_stream_handler(msg['stdin'], | |||||
| functools.partial(self.echo_handler, | |||||
| msg['stdout'])) | |||||
| async def closehandler(ccmsg): | |||||
| if ccmsg['chan'] == msg['stdin']: | |||||
| await self.drain(msg['stdout']) | |||||
| await self._sendcmd(dict(cmd='chanclose', chan=msg['stdout'])) | |||||
| await self._sendcmd(dict(cmd='exit', code=0)) | |||||
| # XXX - not happy w/ this api | |||||
| self.handle_chanclose = closehandler | |||||
| server = TestWSFDServer(self.toserver.get, self.toclient.put) | |||||
| async with self.runClient() as a: | |||||
| with self.assertRaises(RuntimeError) as cm: | |||||
| await a.auth(badauthdict) | |||||
| re = cm.exception | |||||
| self.assertEqual(re.args[0], "Got auth error: 'Invalid auth'") | |||||
| await a.auth(authdict) | |||||
| proc = await a.exec(args=['echo']) | |||||
| stdin, stdout = proc.stdin, proc.stdout | |||||
| data = b'seoiadsujflaksdfj' | |||||
| stdin.write(data) | |||||
| await stdin.drain() | |||||
| r = await stdout.read(len(data)) | |||||
| self.assertEqual(r, data) | |||||
| data = data[::-1] # reverse data | |||||
| stdin.write(data) | |||||
| await stdin.drain() | |||||
| r = await stdout.read(len(data)) | |||||
| self.assertEqual(r, data) | |||||
| stdin.close() | |||||
| await stdin.wait_closed() | |||||
| self.assertEqual(await stdout.read(), b'') | |||||
| self.assertEqual(await proc.wait(), 0) | |||||