From 7f7f187bccb618d35ca367301b669fe18f5af84f Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Tue, 12 Apr 2022 16:45:58 -0700 Subject: [PATCH] add support for server and client to do forwarding to TCP streams... This will be documented in a followup commit... --- setup.py | 5 + wsfwd/__init__.py | 524 +++++++++++++++++++++++++++++++++++++++++++++- wsfwd/__main__.py | 9 + 3 files changed, 531 insertions(+), 7 deletions(-) create mode 100644 wsfwd/__main__.py diff --git a/setup.py b/setup.py index 844c3a4..59a36d4 100644 --- a/setup.py +++ b/setup.py @@ -18,12 +18,17 @@ setup( #download_url='', long_description=open('README.md').read(), install_requires=[ + 'aioconsole', # for aioconsole.stream only + 'fastapi', + 'hypercorn', + 'websockets', ], extras_require = { 'dev': [ 'coverage' ], }, entry_points={ 'console_scripts': [ + 'wsfwd = wsfwd.__main__:main', ] } ) diff --git a/wsfwd/__init__.py b/wsfwd/__init__.py index 8ba15f5..d59ee73 100644 --- a/wsfwd/__init__.py +++ b/wsfwd/__init__.py @@ -1,13 +1,28 @@ +import aioconsole.stream +import argparse import asyncio import contextlib import functools +import io import itertools import json +import os +import shlex +import shutil +import sys +import tempfile import unittest +import websockets import traceback from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.websockets import WebSocket +from hypercorn.__main__ import main as hypercorn_main +from hypercorn.config import Config +from hypercorn.asyncio import serve +from io import TextIOWrapper from typing import Dict, Any from unittest.mock import patch, Mock, AsyncMock @@ -15,6 +30,18 @@ PIPE = object() __all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ] +def _debprint(*args): # pragma: no cover + import traceback, sys, os.path + st = traceback.extract_stack(limit=2)[0] + + sep = '' + if args: + sep = ':' + + print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep), + *args, file=sys.stderr) + sys.stderr.flush() + def timeout(timeout): def timeout_wrapper(fun): @functools.wraps(fun) @@ -25,13 +52,25 @@ def timeout(timeout): return timeout_wrapper -def _tbprinter(fun): #pragma: no cover +def _atbprinter(fun): #pragma: no cover @functools.wraps(fun) async def wrapper(*args, **kwargs): try: return await fun(*args, **kwargs) except Exception: - print('in tbprinter:', repr(fun)) + _debprint('in atbprinter:', repr(fun)) + traceback.print_exc() + raise + + return wrapper + +def _tbprinter(fun): #pragma: no cover + @functools.wraps(fun) + def wrapper(*args, **kwargs): + try: + return fun(*args, **kwargs) + except Exception: + _debprint('in tbprinter:', repr(fun)) traceback.print_exc() raise @@ -337,6 +376,8 @@ class WSFWDCommon: self._waitingresp[cmd['id']] = fut + #_debprint('sendcmd:', repr(self), repr(cmd)) + await self._sendcmd(cmd) return await fut @@ -347,13 +388,14 @@ class WSFWDCommon: can be used for either commands or response to commands. ''' + #_debprint('_sendcmd:', repr(self), repr(self._writer), repr(cmd)) await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8')) async def _proccmds(self, msg): '''Interal routine for dispatching commands.''' msg = json.loads(msg) - #print('_proccmds:', repr(self), repr(msg)) + #_debprint('_proccmds:', repr(self), repr(msg)) if 'cmd' in msg: handler = getattr(self, 'handle_%s' % msg['cmd']) msgid = msg['id'] @@ -367,11 +409,13 @@ class WSFWDCommon: error=e.args[0]) # assert resp['id'] == msgid - #print('response:', repr(resp)) + #_debprint('response:', repr(resp)) await self._sendcmd(resp) + #_debprint('sent:', repr(resp)) elif 'resp' in msg: - fut = self._waitingresp.pop(msg['id']) + #_debprint('got resp:', repr(self), repr(msg)) + fut = self._waitingresp.pop(msg['id']) fut.set_result(msg) class WSFWDServer(WSFWDCommon): @@ -420,6 +464,27 @@ class WSFWDClient(WSFWDCommon): async def _pushdata(writer, data): writer.feed_data(data) + async def connect(self, args): + ''' + Returns a StreamReader, StreamWriter pair. + ''' + + # get the stdin/stdout setup + + self._stdin = WFStreamWriter(self, 1) + self._stdout = asyncio.StreamReader() + + self._procmsgs[2] = functools.partial(self._pushdata, self._stdout) + + self._proc = None + + rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2)) + + if 'error' in rsp: + raise RuntimeError(rsp['error']) + + return self._stdout, self._stdin + async def exec(self, args, stdin=PIPE, stdout=PIPE): ''' Returns a WFProcess instance. WFProcess is very similar @@ -442,6 +507,453 @@ class WSFWDClient(WSFWDCommon): return self._proc +def convert_to_ws(url): + if url.startswith('http://'): + url = url.replace('http', 'ws', 1) + elif url.startswith('https://'): + url = url.replace('https', 'wss', 1) + + return url + +# how to do stdin/stdout via async: +# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130 +async def fwd_data(reader, writer): + while True: + data = await reader.read(16384) + if data == b'': + #_debprint('fwd_data eof', repr(reader), repr(writer)) + writer.close() + await writer.wait_closed() + #_debprint('fwd_data done', repr(reader), repr(writer)) + return + + #_debprint('fwd_data data', repr(reader), repr(writer), len(data)) + writer.write(data) + + await writer.drain() + +async def run_connect(url, ipport): + url = convert_to_ws(url) + stdin, stdout = await aioconsole.stream.get_standard_streams() + + async with websockets.connect(url) as ws, WSFWDClient(ws.recv, + ws.send) as client: + try: + rdr, wtr = await client.connect(args=ipport) + + #_debprint('in', repr(stdin), repr(wtr)) + #_debprint('out', repr(rdr), repr(stdout)) + + await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout)) + + sys.exit(0) + + except RuntimeError as e: + print('failed to exec: %s' % e.args) + + # not a fan of this, shouldn't be needed, but + # how tests are run w/ runAsyncMain, it is + # required here. + sys.stdout.flush() + + sys.exit(1) + +class HandleConnectLimited(WSFWDServer): + def __init__(self, *args, limited, **kwargs): + super().__init__(*args, **kwargs) + + self._limited = limited + self._did_connect = False + self._finish_handler = asyncio.Event() + + async def shutdown(self): + pass + + def _validatearg(self, ipport): + if ipport in self._limited: + return + + raise RuntimeError('not allowed ipport: %s' % repr(ipport)) + + async def process_writer(self, data): + wrt = self.__writer + wrt.write(data) + await wrt.drain() + + async def process_reader(self): + rdr = self.__reader + stream = self._reader_stream + + try: + while True: + data = await rdr.read(16384) + if not data: + break + self.sendstream(stream, data) + await self.drain(stream) + finally: + await self.sendcmd(dict(cmd='chanclose', chan=stream)) + + async def handle_chanclose(self, msg): + self.clear_stream_handler(self._writer_stream) + self.__writer.close() + await self.__writer.wait_closed() + self._writer_event.set() + + async def handle_connect(self, msg): + if self._did_connect: + raise RuntimeError('already did connect') + + ipport = msg['args'] + + self._validatearg(ipport) + + host, port = ipport.split(':') + port = int(port) + + self.__reader, self.__writer = await asyncio.open_connection( + host=host, port=port) + + self._did_connect = True + + self._writer_stream = msg['stdin'] + self._reader_stream = msg['stdout'] + + # handle writer + self._writer_event = asyncio.Event() + self.add_stream_handler(msg['stdin'], self.process_writer) + + # handle reader + self._reader_task = asyncio.create_task(self.process_reader()) + + async def get_finish_handler(self): + return await self._finish_handler.wait() + +def create_conn_server(app, args): + @app.websocket('/connect') + async def connect_ws(webSocket: WebSocket): + await webSocket.accept() + try: + #_debprint('ccs:', repr(webSocket), repr(webSocket.receive_bytes), repr(webSocket.send_bytes)) + async with HandleConnectLimited(webSocket.receive_bytes, + webSocket.send_bytes, limited=args) as server: + await server.get_finish_handler() + finally: + await webSocket.close() + +@_tbprinter +def real_main(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(title='subcommands', + dest='subparser_name', + description='valid subcommands', help='additional help') + + parser_connect = subparsers.add_parser('connect', + help='connect to a socket at the specified URL') + parser_connect.add_argument('url', type=str, + help='the URL to issue the connect command to') + parser_connect.add_argument('ipport', type=str, help=':') + + parser_serve = subparsers.add_parser('serve', + help='Serve connection requests to the provided : tuples.') + parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules') + parser_serve.add_argument('ipport', type=str, nargs='+', help='[:]') + + args = parser.parse_args() + #print(repr(args), file=sys.__stderr__) + + if args.subparser_name == 'connect': + return run_connect(args.url, args.ipport) + elif args.subparser_name == 'serve': + # make hypercorn args + hypercornargs = sum((shlex.split(x) for x in args.hypercorn), []) + + # make app + global app + + app = FastAPI() + create_conn_server(app, args.ipport) + + import wsfwd + #_debprint(repr(wsfwd.app)) + hypercornargs.append('%s:app' % __name__) + hypercorn_main(sys_args=hypercornargs) + + return + + parser.print_usage() + async def fun(): + sys.exit(5) + + return fun() + +def wrap_connect(mockobj, readdata=b''): + reader = Mock() + reader.read = AsyncMock() + reader.read.side_effect = [ readdata, b'' ] + + writer = Mock() + writer.drain = AsyncMock() + writer.wait_closed = AsyncMock() + + mockobj.return_value = (reader, writer) + +class TestCTW(unittest.TestCase): + def test_convert_to_ws(self): + self.assertEqual(convert_to_ws('http://foo/'), 'ws://foo/') + self.assertEqual(convert_to_ws('https://foo/'), 'wss://foo/') + +class TestServer(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + d = os.path.realpath(tempfile.mkdtemp()) + self.basetempdir = d + + self.shutdown_event = asyncio.Event() + + self.socketpath = os.path.join(self.basetempdir, 'wstest.sock') + + config = Config() + config.graceful_timeout = .01 + config.bind = [ 'unix:' + self.socketpath ] + config.loglevel = 'ERROR' + self.config = config + + async def asyncTearDown(self): + self.app = None + + self.shutdown_event.set() + + await self.serv_task + + shutil.rmtree(self.basetempdir) + self.basetempdir = None + + @patch('asyncio.open_connection') + @timeout(2) + async def test_connect(self, oc): + connarg = '127.0.0.1:12345' + + app = FastAPI() + create_conn_server(app, [ connarg ]) + + self.serv_task = asyncio.create_task(serve(app, self.config, + shutdown_trigger=self.shutdown_event.wait)) + + # get the unix domain socket connected + # need a startup_trigger + await asyncio.sleep(.01) + + def wrapper(corofun): + async def foo(*args, **kwargs): + r = await corofun(*args, **kwargs) + #print('foo:', repr(corofun), repr((args, kwargs)), repr(r)) + return r + + return foo + + async with websockets.unix_connect(self.socketpath, + 'ws://foo/connect') as websocket, \ + WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client: + mstdout = AsyncMock() + + echodata = b'somedata' + wrap_connect(oc, readdata=echodata) + + # that bad args are rejected + with self.assertRaises(RuntimeError): + await client.connect('192.168.0.1:12345') + + with self.assertRaises(RuntimeError): + await client.connect('127.0.0.1:348') + + client.add_stream_handler(2, mstdout) + reader, writer = await client.connect(connarg) + + # that it cannot be connected'd a second time + with self.assertRaises(RuntimeError): + await client.connect(connarg) + + writer.write(echodata) + await writer.drain() + + # that we get our data + self.assertEqual(await reader.read(len(echodata)), echodata) + + # and that there is no more + self.assertEqual(await reader.read(len(echodata)), b'') + + # and we are truly at EOF + self.assertTrue(reader.at_eof()) + + writer.close() + await writer.wait_closed() + + # make sure that it was called w/ the correct arugments + oc.assert_called_with(host='127.0.0.1', port=12345) + + # spin things, not sure best way to handle this + await asyncio.sleep(.01) + + # make sure that the writer was closed + oc.return_value[1].close.assert_called_with() + +class TestMain(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.toclient = asyncio.Queue() + self.toserver = asyncio.Queue() + + async def asyncTearDown(self): + self.assertTrue(self.toclient.empty()) + self.assertTrue(self.toserver.empty()) + + def setup_websockets_mock(self, webcon): + conobj = Mock() + + webcon().__aenter__.return_value = conobj + + conobj.send = self.toserver.put + conobj.recv = self.toclient.get + + webcon.reset_mock() + + @contextlib.contextmanager + def make_pipe(self): + r, w = os.pipe() + with os.fdopen(r, 'rb', buffering=65536) as readfl, \ + os.fdopen(w, 'wb', buffering=65536) as writefl: + yield readfl, writefl + + # too lazy to make this async since async file-like objects + # aren't standard yet in Python + def copytask(self, reader, writer, doclose=True): + while True: + data = reader.read(16384) + #_debprint('ct', repr(reader), repr(writer), doclose, len(data)) + if not data: + if doclose: + writer.close() + #_debprint('ct done', repr(reader), repr(writer), doclose) + return + + writer.write(data) + + @_atbprinter + async def runAsyncMain(self, fun=real_main, stdin=''): + # make stdin bytes + if isinstance(stdin, str): + stdin = stdin.encode() + + # make stdout + stdout = io.BytesIO() + + # Data path: + # stdin -> stdin_task -> stdinwriter -> pipe -> + # stdinreader -> sys.stdin -> real_main -> sys.stdout -> + # stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout + # + # How things get closed: + # stdin is already "closed", stdin_task will close + # stdinwriter when eof encountered (doclose). + + # create the pipes needed + with self.make_pipe() as (stdinreader, stdinwriter), \ + self.make_pipe() as (stdoutreader, stdoutwriter): + # setup the threads to move data + loop = asyncio.get_running_loop() + stdin_task = loop.run_in_executor(None, self.copytask, + io.BytesIO(stdin), stdinwriter) + + # do not close stdout, otherwise we cannot call + # getvalue on BytesIO object + stdout_task = loop.run_in_executor(None, self.copytask, + stdoutreader, stdout, False) + + #_debprint('ram', repr(stdout_task), repr(stdoutreader)) + + # insert the pipes + with patch.dict(sys.__dict__, + dict(stdin=TextIOWrapper(stdinreader), + stdout=TextIOWrapper(stdoutwriter))): + try: + # run the function + await fun() + #await asyncio.wait_for(fun(), 1) + ret = 0 #pragma: no cover + except SystemExit as e: + #_debprint('exit', repr(e.code)) + ret = e.code + + # No one to write anything anymore + # close so stdout_task will end + stdoutwriter.close() + + # make sure all the data has been copied + await asyncio.gather(stdin_task, stdout_task) + + stdoutvalue = stdout.getvalue() + + return ret, stdoutvalue + + def test_mainserver(self): + with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \ + self.assertRaises(SystemExit) as context: + asyncio.run(real_main()) + + self.assertEqual(context.exception.code, 2) + + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs: + real_main() + + hcm.assert_called_with(sys_args=[ 'wsfwd:app' ]) + + self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ]) + self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) + + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs: + real_main() + + hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ]) + + self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ]) + self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) + + @timeout(2) + async def test_socket(self): + class TestServer(WSFWDCommon): + async def echo_handler(self, stream, msg): + self.sendstream(stream, msg) + await self.drain(stream) + + async def handle_chanclose(self, msg): + self.add_tasks(asyncio.create_task(self.sendcmd( + dict(cmd='chanclose', + chan=self._stdout_stream)))) + + async def handle_connect(self, msg): + self._stdout_stream = msg['stdout'] + assert msg['args'] == '127.0.0.1:38493' + self.add_stream_handler(msg['stdin'], + functools.partial(self.echo_handler, + msg['stdout'])) + + server = TestServer(self.toserver.get, self.toclient.put) + + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'connect', + 'https://example.com/connectws', '127.0.0.1:38493', ])), \ + patch('websockets.connect') as webcon: + self.setup_websockets_mock(webcon) + + inpdata = bytes(range(0, 255)) + + ret, stdout = await self.runAsyncMain(stdin=inpdata) + + await server.__aexit__(None, None, None) + + self.assertEqual(stdout, inpdata) + + self.assertEqual(ret, 0) + class Test(unittest.IsolatedAsyncioTestCase): @staticmethod def _encodecmd(payload): @@ -474,7 +986,6 @@ class Test(unittest.IsolatedAsyncioTestCase): @timeout(2) async def test_authfail(self): - @_tbprinter async def fake_server(reader, writer): cmd = await reader() @@ -508,7 +1019,6 @@ class Test(unittest.IsolatedAsyncioTestCase): writerdata2 = b'oiwuersldkj' readerdata1 = b'weoiusdofiuwe' - @_tbprinter async def fake_server(reader, writer): # get the auth command auth = await reader() diff --git a/wsfwd/__main__.py b/wsfwd/__main__.py new file mode 100644 index 0000000..f7bf6fc --- /dev/null +++ b/wsfwd/__main__.py @@ -0,0 +1,9 @@ +import asyncio + +from wsfwd import real_main + +def main(): + asyncio.run(real_main()) + +if __name__ == '__main__': #pragma: no cover + main()