diff --git a/README.md b/README.md index c1f66e7..13748fb 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,64 @@ port blocking, or allow more custom routing and execution. It is designed so that in the future, it could support forwarding stderr separately, but also out of band messages, such as window change information, so that a full tty could be forwarded over the -connection. +connection. As transporting a protocol like ssh does all of this +for you, it is unlikely to be expanded to support it. + +Usage +----- + +Included is a sample program that can be used to forward stream +connections to one of the specified hosts and port. + +First install the dependencies. Note that often it's best to install +it into a virtual environment[venv], so, create one: +``` +python -m venv dirname +``` + +Where dirname will be a directory that will be created w/ the virtual +environment. Start the virtual environment (if you're using bash/zsh, +other shells are documented in the link above): +``` +source ./dirname/bin/activate +``` + +Install wsfwd: +``` +pip install git+https://www.funkthat.com/gitea/jmg/wsfwd +``` + +Run the server (for example to forward to the local machine's sshd): +``` +wsfwd serve 127.0.0.1:22 +``` + +Multiple `ip:port` maybe specified on the command to allow forwarding +to multiple hosts. + +This will start up a webserver. It uses hypercorn, so, all the +hypercorn command line arguments may be specified via the `--hypercorn` +argument. For example, to bind only to 127.0.0.1, you can run: +``` +wsfwd serve --hypercorn '--bind 127.0.0.1' 127.0.0.1:22 +``` + +Once that has been setup, the client can be used to connect to the +specified server: +``` +wsfwd connect http://127.0.0.1/connect 127.0.0.1:22 +``` + +This is most useful to be specified to ProxyCommand in your .ssh/config +file: +``` +Host localhost + ProxyCommand /path/to/dirname/bin/wsfwd connect http://127.0.0.1:8000/connect 127.0.0.1:22 +``` + +There are many ways to use it. Another option could also be to launch +it via inetd to transparently pass normal ssh sessions to remote hosts. + Protocol -------- @@ -107,3 +164,5 @@ with other implementations. FastAPI uses starlette: https://www.starlette.io/websockets/ Client: https://github.com/aaugustin/websockets + +[venv]: https://docs.python.org/3/library/venv.html diff --git a/wsfwd/__init__.py b/wsfwd/__init__.py index d59ee73..c622747 100644 --- a/wsfwd/__init__.py +++ b/wsfwd/__init__.py @@ -28,7 +28,13 @@ from unittest.mock import patch, Mock, AsyncMock PIPE = object() -__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ] +__all__ = [ + 'create_conn_server', + 'HandleConnectLimited', + 'WSFWDServer', + 'WSFWDClient', + 'WFProcess', +] def _debprint(*args): # pragma: no cover import traceback, sys, os.path @@ -46,7 +52,8 @@ def timeout(timeout): def timeout_wrapper(fun): @functools.wraps(fun) async def wrapper(*args, **kwargs): - return await asyncio.wait_for(fun(*args, **kwargs), timeout) + return await asyncio.wait_for(fun(*args, **kwargs), + timeout) return wrapper @@ -208,7 +215,8 @@ class WFStreamWriter: return self._closed async def _close_task(self): - await self._client.sendcmd(dict(cmd='chanclose', chan=self._stream)) + await self._client.sendcmd(dict(cmd='chanclose', + chan=self._stream)) self._closed_event.set() self._closed_event = None @@ -342,7 +350,8 @@ class WSFWDCommon: senddata = b''.join(data) if senddata: - await self._writer(stream.to_bytes(1, 'big') + senddata) + await self._writer(stream.to_bytes(1, 'big') + + senddata) async def __aenter__(self): return self @@ -458,7 +467,8 @@ class WSFWDClient(WSFWDCommon): rsp = await self.sendcmd(dict(cmd='auth', auth=auth)) if 'error' in rsp: - raise RuntimeError('Got auth error: %s' % repr(rsp['error'])) + raise RuntimeError('Got auth error: %s' % + repr(rsp['error'])) @staticmethod async def _pushdata(writer, data): @@ -474,11 +484,13 @@ class WSFWDClient(WSFWDCommon): self._stdin = WFStreamWriter(self, 1) self._stdout = asyncio.StreamReader() - self._procmsgs[2] = functools.partial(self._pushdata, self._stdout) + self._procmsgs[2] = functools.partial(self._pushdata, + self._stdout) self._proc = None - rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2)) + rsp = await self.sendcmd(dict(cmd='connect', args=args, + stdin=1, stdout=2)) if 'error' in rsp: raise RuntimeError(rsp['error']) @@ -496,11 +508,13 @@ class WSFWDClient(WSFWDCommon): self._stdin = WFStreamWriter(self, 1) self._stdout = asyncio.StreamReader() - self._procmsgs[2] = functools.partial(self._pushdata, self._stdout) + self._procmsgs[2] = functools.partial(self._pushdata, + self._stdout) self._proc = WFProcess(self, self._stdin, self._stdout) - rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1, stdout=2)) + rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1, + stdout=2)) if 'error' in rsp: raise RuntimeError(rsp['error']) @@ -544,7 +558,8 @@ async def run_connect(url, ipport): #_debprint('in', repr(stdin), repr(wtr)) #_debprint('out', repr(rdr), repr(stdout)) - await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout)) + await asyncio.gather(fwd_data(stdin, wtr), + fwd_data(rdr, stdout)) sys.exit(0) @@ -559,6 +574,14 @@ async def run_connect(url, ipport): sys.exit(1) class HandleConnectLimited(WSFWDServer): + ''' + A server that allows connection to the specified list of ip:port's + specified in the arguemnt `limited`. + + It is recommend to use `create_conn_server` directly, instead of + instantiating it yourself. + ''' + def __init__(self, *args, limited, **kwargs): super().__init__(*args, **kwargs) @@ -629,7 +652,19 @@ class HandleConnectLimited(WSFWDServer): async def get_finish_handler(self): return await self._finish_handler.wait() -def create_conn_server(app, args): +def create_conn_server(app, *args): + '''Add a route to app for /connect, that will allow connections + to the provided ip:port. + + Example usage: + ``` + from fastapi import FastAPI + from wsfwd import create_conn_server + app = FastAPI() + create_conn_server(app, '127.0.0.1:22') + ``` + ''' + @app.websocket('/connect') async def connect_ws(webSocket: WebSocket): await webSocket.accept() @@ -653,12 +688,16 @@ def real_main(): help='connect to a socket at the specified URL') parser_connect.add_argument('url', type=str, help='the URL to issue the connect command to') - parser_connect.add_argument('ipport', type=str, help=':') + parser_connect.add_argument('ipport', type=str, + help=':') parser_serve = subparsers.add_parser('serve', help='Serve connection requests to the provided : tuples.') - parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules') - parser_serve.add_argument('ipport', type=str, nargs='+', help='[:]') + parser_serve.add_argument('--hypercorn', type=str, action='append', + default=[], + help='arguments to hypercorn, will be split per standard sh rules') + parser_serve.add_argument('ipport', type=str, nargs='+', + help='[:]') args = parser.parse_args() #print(repr(args), file=sys.__stderr__) @@ -667,13 +706,14 @@ def real_main(): return run_connect(args.url, args.ipport) elif args.subparser_name == 'serve': # make hypercorn args - hypercornargs = sum((shlex.split(x) for x in args.hypercorn), []) + hypercornargs = sum((shlex.split(x) for x in args.hypercorn), + []) # make app global app app = FastAPI() - create_conn_server(app, args.ipport) + create_conn_server(app, *args.ipport) import wsfwd #_debprint(repr(wsfwd.app)) @@ -735,7 +775,7 @@ class TestServer(unittest.IsolatedAsyncioTestCase): connarg = '127.0.0.1:12345' app = FastAPI() - create_conn_server(app, [ connarg ]) + create_conn_server(app, connarg) self.serv_task = asyncio.create_task(serve(app, self.config, shutdown_trigger=self.shutdown_event.wait)) @@ -754,7 +794,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase): async with websockets.unix_connect(self.socketpath, 'ws://foo/connect') as websocket, \ - WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client: + WSFWDClient(wrapper(websocket.recv), + wrapper(websocket.send)) as client: mstdout = AsyncMock() echodata = b'somedata' @@ -778,7 +819,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase): await writer.drain() # that we get our data - self.assertEqual(await reader.read(len(echodata)), echodata) + self.assertEqual(await reader.read(len(echodata)), + echodata) # and that there is no more self.assertEqual(await reader.read(len(echodata)), b'') @@ -896,26 +938,50 @@ class TestMain(unittest.IsolatedAsyncioTestCase): return ret, stdoutvalue def test_mainserver(self): - with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \ + with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, + dict(args=[ ], stderr=fp)), \ self.assertRaises(SystemExit) as context: asyncio.run(real_main()) self.assertEqual(context.exception.code, 2) - with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs: + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', + '127.0.0.1:12345', ])), \ + patch('wsfwd.hypercorn_main') as hcm, \ + patch('wsfwd.create_conn_server') as ccs: real_main() hcm.assert_called_with(sys_args=[ 'wsfwd:app' ]) - self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ]) + self.assertEqual(ccs.mock_calls[0][1][1], + '127.0.0.1:12345') self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) - with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs: + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', + '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), \ + patch('wsfwd.hypercorn_main') as hcm, \ + patch('wsfwd.create_conn_server') as ccs: real_main() - hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ]) + hcm.assert_called_with(sys_args=['-b', '127.0.0.1', + 'wsfwd:app' ]) - self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ]) + self.assertEqual(ccs.mock_calls[0][1][1], + '127.0.0.1:12345') + self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) + + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', + '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', + 'another', ])), \ + patch('wsfwd.hypercorn_main') as hcm, \ + patch('wsfwd.create_conn_server') as ccs: + real_main() + + hcm.assert_called_with(sys_args=['-b', '127.0.0.1', + 'wsfwd:app' ]) + + self.assertEqual(ccs.mock_calls[0][1][1], + '127.0.0.1:12345', 'another') self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) @timeout(2) @@ -972,7 +1038,8 @@ class Test(unittest.IsolatedAsyncioTestCase): return WSFWDClient(self.toclient.get, self.toserver.put) def runFakeServer(self, func): - return asyncio.create_task(func(self.toserver.get, self.toclient.put)) + return asyncio.create_task(func(self.toserver.get, + self.toclient.put)) @timeout(2) @patch('wsfwd.WSFWDCommon.shutdown') @@ -991,7 +1058,8 @@ class Test(unittest.IsolatedAsyncioTestCase): msg = self.decode_cmdmsg(cmd) self.assertEqual(msg['cmd'], 'auth') - await writer(self._encodecmd(dict(resp='auth', id=msg['id'], error='Invalid auth'))) + await writer(self._encodecmd(dict(resp='auth', + id=msg['id'], error='Invalid auth'))) serv_task = self.runFakeServer(fake_server) @@ -1100,7 +1168,8 @@ class Test(unittest.IsolatedAsyncioTestCase): ccmsg = await reader() msg = self.decode_cmdmsg(ccmsg) - self.assertEqual(msg, dict(resp='chanclose', id=1, chan=2)) + self.assertEqual(msg, dict(resp='chanclose', id=1, + chan=2)) # return the exit code await writer(self._encodecmd(dict(cmd='exit', id=2, @@ -1120,9 +1189,11 @@ class Test(unittest.IsolatedAsyncioTestCase): # the when exec fails w/ error it is caught with self.assertRaises(RuntimeError): - await a.exec(args=execargs, stdin=PIPE, stdout=PIPE) + await a.exec(args=execargs, stdin=PIPE, + stdout=PIPE) - proc = await a.exec(args=execargs, stdin=PIPE, stdout=PIPE) + proc = await a.exec(args=execargs, stdin=PIPE, + stdout=PIPE) writer, reader = proc.stdin, proc.stdout @@ -1164,7 +1235,8 @@ class Test(unittest.IsolatedAsyncioTestCase): self.assertFalse(procwaittask.done()) # that the wait_closed - waitclosedtask = asyncio.create_task(writer.wait_closed()) + waitclosedtask = asyncio.create_task( + writer.wait_closed()) # and when allowed to run await asyncio.sleep(0) @@ -1218,19 +1290,23 @@ class Test(unittest.IsolatedAsyncioTestCase): r = json.loads(r[1:]) r['resp'] = r['cmd'] del r['cmd'] - await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8')) + await self.toserver.put(b'\x00' + + json.dumps(r).encode('utf-8')) t = asyncio.create_task(task()) t2 = asyncio.create_task(task()) - async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn: + async with WSFWDCommon(self.toserver.get, + self.toclient.put) as cmn: r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar')) - self.assertEqual(r, dict(resp='somecmd', id=1, foo='bar')) + self.assertEqual(r, dict(resp='somecmd', id=1, + foo='bar')) r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar')) - self.assertEqual(r, dict(resp='somecmd', id=2, foo='bar')) + self.assertEqual(r, dict(resp='somecmd', id=2, + foo='bar')) await t await t2 @@ -1247,20 +1323,24 @@ class Test(unittest.IsolatedAsyncioTestCase): r = json.loads(r[1:]) r['resp'] = r['cmd'] del r['cmd'] - await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8')) + await self.toserver.put(b'\x00' + + json.dumps(r).encode('utf-8')) t = asyncio.create_task(task()) t2 = asyncio.create_task(task()) - async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn: + async with WSFWDCommon(self.toserver.get, + self.toclient.put) as cmn: try: - r = asyncio.create_task(cmn.sendcmd(dict(cmd='somecmd', foo='bar'))) + r = asyncio.create_task(cmn.sendcmd(dict( + cmd='somecmd', foo='bar'))) loop = asyncio.get_running_loop() loop.call_later(.1, ev.set) # make sure that we can schedule a second one - await cmn.sendcmd(dict(cmd='somecmd', foo='bar')) + await cmn.sendcmd(dict(cmd='somecmd', + foo='bar')) finally: await r