Browse Source

add some documentation, wrap a few long lines..

main
John-Mark Gurney 3 years ago
parent
commit
fd2c5f9fba
2 changed files with 179 additions and 40 deletions
  1. +60
    -1
      README.md
  2. +119
    -39
      wsfwd/__init__.py

+ 60
- 1
README.md View File

@@ -9,7 +9,64 @@ port blocking, or allow more custom routing and execution.
It is designed so that in the future, it could support forwarding It is designed so that in the future, it could support forwarding
stderr separately, but also out of band messages, such as window stderr separately, but also out of band messages, such as window
change information, so that a full tty could be forwarded over the change information, so that a full tty could be forwarded over the
connection.
connection. As transporting a protocol like ssh does all of this
for you, it is unlikely to be expanded to support it.

Usage
-----

Included is a sample program that can be used to forward stream
connections to one of the specified hosts and port.

First install the dependencies. Note that often it's best to install
it into a virtual environment[venv], so, create one:
```
python -m venv dirname
```

Where dirname will be a directory that will be created w/ the virtual
environment. Start the virtual environment (if you're using bash/zsh,
other shells are documented in the link above):
```
source ./dirname/bin/activate
```

Install wsfwd:
```
pip install git+https://www.funkthat.com/gitea/jmg/wsfwd
```

Run the server (for example to forward to the local machine's sshd):
```
wsfwd serve 127.0.0.1:22
```

Multiple `ip:port` maybe specified on the command to allow forwarding
to multiple hosts.

This will start up a webserver. It uses hypercorn, so, all the
hypercorn command line arguments may be specified via the `--hypercorn`
argument. For example, to bind only to 127.0.0.1, you can run:
```
wsfwd serve --hypercorn '--bind 127.0.0.1' 127.0.0.1:22
```

Once that has been setup, the client can be used to connect to the
specified server:
```
wsfwd connect http://127.0.0.1/connect 127.0.0.1:22
```

This is most useful to be specified to ProxyCommand in your .ssh/config
file:
```
Host localhost
ProxyCommand /path/to/dirname/bin/wsfwd connect http://127.0.0.1:8000/connect 127.0.0.1:22
```

There are many ways to use it. Another option could also be to launch
it via inetd to transparently pass normal ssh sessions to remote hosts.



Protocol Protocol
-------- --------
@@ -107,3 +164,5 @@ with other implementations.
FastAPI uses starlette: https://www.starlette.io/websockets/ FastAPI uses starlette: https://www.starlette.io/websockets/


Client: https://github.com/aaugustin/websockets Client: https://github.com/aaugustin/websockets

[venv]: https://docs.python.org/3/library/venv.html

+ 119
- 39
wsfwd/__init__.py View File

@@ -28,7 +28,13 @@ from unittest.mock import patch, Mock, AsyncMock


PIPE = object() PIPE = object()


__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ]
__all__ = [
'create_conn_server',
'HandleConnectLimited',
'WSFWDServer',
'WSFWDClient',
'WFProcess',
]


def _debprint(*args): # pragma: no cover def _debprint(*args): # pragma: no cover
import traceback, sys, os.path import traceback, sys, os.path
@@ -46,7 +52,8 @@ def timeout(timeout):
def timeout_wrapper(fun): def timeout_wrapper(fun):
@functools.wraps(fun) @functools.wraps(fun)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
return await asyncio.wait_for(fun(*args, **kwargs), timeout)
return await asyncio.wait_for(fun(*args, **kwargs),
timeout)


return wrapper return wrapper


@@ -208,7 +215,8 @@ class WFStreamWriter:
return self._closed return self._closed


async def _close_task(self): async def _close_task(self):
await self._client.sendcmd(dict(cmd='chanclose', chan=self._stream))
await self._client.sendcmd(dict(cmd='chanclose',
chan=self._stream))


self._closed_event.set() self._closed_event.set()
self._closed_event = None self._closed_event = None
@@ -342,7 +350,8 @@ class WSFWDCommon:
senddata = b''.join(data) senddata = b''.join(data)


if senddata: if senddata:
await self._writer(stream.to_bytes(1, 'big') + senddata)
await self._writer(stream.to_bytes(1, 'big') +
senddata)


async def __aenter__(self): async def __aenter__(self):
return self return self
@@ -458,7 +467,8 @@ class WSFWDClient(WSFWDCommon):
rsp = await self.sendcmd(dict(cmd='auth', auth=auth)) rsp = await self.sendcmd(dict(cmd='auth', auth=auth))


if 'error' in rsp: if 'error' in rsp:
raise RuntimeError('Got auth error: %s' % repr(rsp['error']))
raise RuntimeError('Got auth error: %s' %
repr(rsp['error']))


@staticmethod @staticmethod
async def _pushdata(writer, data): async def _pushdata(writer, data):
@@ -474,11 +484,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1) self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader() self._stdout = asyncio.StreamReader()


self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)


self._proc = None self._proc = None


rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='connect', args=args,
stdin=1, stdout=2))


if 'error' in rsp: if 'error' in rsp:
raise RuntimeError(rsp['error']) raise RuntimeError(rsp['error'])
@@ -496,11 +508,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1) self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader() self._stdout = asyncio.StreamReader()


self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)


self._proc = WFProcess(self, self._stdin, self._stdout) self._proc = WFProcess(self, self._stdin, self._stdout)


rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1,
stdout=2))


if 'error' in rsp: if 'error' in rsp:
raise RuntimeError(rsp['error']) raise RuntimeError(rsp['error'])
@@ -544,7 +558,8 @@ async def run_connect(url, ipport):
#_debprint('in', repr(stdin), repr(wtr)) #_debprint('in', repr(stdin), repr(wtr))
#_debprint('out', repr(rdr), repr(stdout)) #_debprint('out', repr(rdr), repr(stdout))


await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout))
await asyncio.gather(fwd_data(stdin, wtr),
fwd_data(rdr, stdout))


sys.exit(0) sys.exit(0)


@@ -559,6 +574,14 @@ async def run_connect(url, ipport):
sys.exit(1) sys.exit(1)


class HandleConnectLimited(WSFWDServer): class HandleConnectLimited(WSFWDServer):
'''
A server that allows connection to the specified list of ip:port's
specified in the arguemnt `limited`.

It is recommend to use `create_conn_server` directly, instead of
instantiating it yourself.
'''

def __init__(self, *args, limited, **kwargs): def __init__(self, *args, limited, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)


@@ -629,7 +652,19 @@ class HandleConnectLimited(WSFWDServer):
async def get_finish_handler(self): async def get_finish_handler(self):
return await self._finish_handler.wait() return await self._finish_handler.wait()


def create_conn_server(app, args):
def create_conn_server(app, *args):
'''Add a route to app for /connect, that will allow connections
to the provided ip:port.

Example usage:
```
from fastapi import FastAPI
from wsfwd import create_conn_server
app = FastAPI()
create_conn_server(app, '127.0.0.1:22')
```
'''

@app.websocket('/connect') @app.websocket('/connect')
async def connect_ws(webSocket: WebSocket): async def connect_ws(webSocket: WebSocket):
await webSocket.accept() await webSocket.accept()
@@ -653,12 +688,16 @@ def real_main():
help='connect to a socket at the specified URL') help='connect to a socket at the specified URL')
parser_connect.add_argument('url', type=str, parser_connect.add_argument('url', type=str,
help='the URL to issue the connect command to') help='the URL to issue the connect command to')
parser_connect.add_argument('ipport', type=str, help='<ip address>:<port>')
parser_connect.add_argument('ipport', type=str,
help='<ip address>:<port>')


parser_serve = subparsers.add_parser('serve', parser_serve = subparsers.add_parser('serve',
help='Serve connection requests to the provided <ip>:<port> tuples.') help='Serve connection requests to the provided <ip>:<port> tuples.')
parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+', help='<ip address>[:<port>]')
parser_serve.add_argument('--hypercorn', type=str, action='append',
default=[],
help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+',
help='<ip address>[:<port>]')


args = parser.parse_args() args = parser.parse_args()
#print(repr(args), file=sys.__stderr__) #print(repr(args), file=sys.__stderr__)
@@ -667,13 +706,14 @@ def real_main():
return run_connect(args.url, args.ipport) return run_connect(args.url, args.ipport)
elif args.subparser_name == 'serve': elif args.subparser_name == 'serve':
# make hypercorn args # make hypercorn args
hypercornargs = sum((shlex.split(x) for x in args.hypercorn), [])
hypercornargs = sum((shlex.split(x) for x in args.hypercorn),
[])


# make app # make app
global app global app


app = FastAPI() app = FastAPI()
create_conn_server(app, args.ipport)
create_conn_server(app, *args.ipport)


import wsfwd import wsfwd
#_debprint(repr(wsfwd.app)) #_debprint(repr(wsfwd.app))
@@ -735,7 +775,7 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
connarg = '127.0.0.1:12345' connarg = '127.0.0.1:12345'


app = FastAPI() app = FastAPI()
create_conn_server(app, [ connarg ])
create_conn_server(app, connarg)


self.serv_task = asyncio.create_task(serve(app, self.config, self.serv_task = asyncio.create_task(serve(app, self.config,
shutdown_trigger=self.shutdown_event.wait)) shutdown_trigger=self.shutdown_event.wait))
@@ -754,7 +794,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):


async with websockets.unix_connect(self.socketpath, async with websockets.unix_connect(self.socketpath,
'ws://foo/connect') as websocket, \ 'ws://foo/connect') as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
WSFWDClient(wrapper(websocket.recv),
wrapper(websocket.send)) as client:
mstdout = AsyncMock() mstdout = AsyncMock()


echodata = b'somedata' echodata = b'somedata'
@@ -778,7 +819,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
await writer.drain() await writer.drain()


# that we get our data # that we get our data
self.assertEqual(await reader.read(len(echodata)), echodata)
self.assertEqual(await reader.read(len(echodata)),
echodata)


# and that there is no more # and that there is no more
self.assertEqual(await reader.read(len(echodata)), b'') self.assertEqual(await reader.read(len(echodata)), b'')
@@ -896,26 +938,50 @@ class TestMain(unittest.IsolatedAsyncioTestCase):
return ret, stdoutvalue return ret, stdoutvalue


def test_mainserver(self): def test_mainserver(self):
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__,
dict(args=[ ], stderr=fp)), \
self.assertRaises(SystemExit) as context: self.assertRaises(SystemExit) as context:
asyncio.run(real_main()) asyncio.run(real_main())


self.assertEqual(context.exception.code, 2) self.assertEqual(context.exception.code, 2)


with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main() real_main()


hcm.assert_called_with(sys_args=[ 'wsfwd:app' ]) hcm.assert_called_with(sys_args=[ 'wsfwd:app' ])


self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)


with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main() real_main()


hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ])
hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])


self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345',
'another', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345', 'another')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI) self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)


@timeout(2) @timeout(2)
@@ -972,7 +1038,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
return WSFWDClient(self.toclient.get, self.toserver.put) return WSFWDClient(self.toclient.get, self.toserver.put)


def runFakeServer(self, func): def runFakeServer(self, func):
return asyncio.create_task(func(self.toserver.get, self.toclient.put))
return asyncio.create_task(func(self.toserver.get,
self.toclient.put))


@timeout(2) @timeout(2)
@patch('wsfwd.WSFWDCommon.shutdown') @patch('wsfwd.WSFWDCommon.shutdown')
@@ -991,7 +1058,8 @@ class Test(unittest.IsolatedAsyncioTestCase):


msg = self.decode_cmdmsg(cmd) msg = self.decode_cmdmsg(cmd)
self.assertEqual(msg['cmd'], 'auth') self.assertEqual(msg['cmd'], 'auth')
await writer(self._encodecmd(dict(resp='auth', id=msg['id'], error='Invalid auth')))
await writer(self._encodecmd(dict(resp='auth',
id=msg['id'], error='Invalid auth')))


serv_task = self.runFakeServer(fake_server) serv_task = self.runFakeServer(fake_server)


@@ -1100,7 +1168,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
ccmsg = await reader() ccmsg = await reader()
msg = self.decode_cmdmsg(ccmsg) msg = self.decode_cmdmsg(ccmsg)


self.assertEqual(msg, dict(resp='chanclose', id=1, chan=2))
self.assertEqual(msg, dict(resp='chanclose', id=1,
chan=2))


# return the exit code # return the exit code
await writer(self._encodecmd(dict(cmd='exit', id=2, await writer(self._encodecmd(dict(cmd='exit', id=2,
@@ -1120,9 +1189,11 @@ class Test(unittest.IsolatedAsyncioTestCase):


# the when exec fails w/ error it is caught # the when exec fails w/ error it is caught
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)


proc = await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
proc = await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)


writer, reader = proc.stdin, proc.stdout writer, reader = proc.stdin, proc.stdout


@@ -1164,7 +1235,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
self.assertFalse(procwaittask.done()) self.assertFalse(procwaittask.done())


# that the wait_closed # that the wait_closed
waitclosedtask = asyncio.create_task(writer.wait_closed())
waitclosedtask = asyncio.create_task(
writer.wait_closed())


# and when allowed to run # and when allowed to run
await asyncio.sleep(0) await asyncio.sleep(0)
@@ -1218,19 +1290,23 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:]) r = json.loads(r[1:])
r['resp'] = r['cmd'] r['resp'] = r['cmd']
del r['cmd'] del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))


t = asyncio.create_task(task()) t = asyncio.create_task(task())
t2 = asyncio.create_task(task()) t2 = asyncio.create_task(task())


async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar')) r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))


self.assertEqual(r, dict(resp='somecmd', id=1, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=1,
foo='bar'))


r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar')) r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))


self.assertEqual(r, dict(resp='somecmd', id=2, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=2,
foo='bar'))


await t await t
await t2 await t2
@@ -1247,20 +1323,24 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:]) r = json.loads(r[1:])
r['resp'] = r['cmd'] r['resp'] = r['cmd']
del r['cmd'] del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))


t = asyncio.create_task(task()) t = asyncio.create_task(task())
t2 = asyncio.create_task(task()) t2 = asyncio.create_task(task())


async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
try: try:
r = asyncio.create_task(cmn.sendcmd(dict(cmd='somecmd', foo='bar')))
r = asyncio.create_task(cmn.sendcmd(dict(
cmd='somecmd', foo='bar')))


loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.call_later(.1, ev.set) loop.call_later(.1, ev.set)


# make sure that we can schedule a second one # make sure that we can schedule a second one
await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))
await cmn.sendcmd(dict(cmd='somecmd',
foo='bar'))
finally: finally:
await r await r




Loading…
Cancel
Save