Browse Source

add some documentation, wrap a few long lines..

main
John-Mark Gurney 3 years ago
parent
commit
fd2c5f9fba
2 changed files with 179 additions and 40 deletions
  1. +60
    -1
      README.md
  2. +119
    -39
      wsfwd/__init__.py

+ 60
- 1
README.md View File

@@ -9,7 +9,64 @@ port blocking, or allow more custom routing and execution.
It is designed so that in the future, it could support forwarding
stderr separately, but also out of band messages, such as window
change information, so that a full tty could be forwarded over the
connection.
connection. As transporting a protocol like ssh does all of this
for you, it is unlikely to be expanded to support it.

Usage
-----

Included is a sample program that can be used to forward stream
connections to one of the specified hosts and port.

First install the dependencies. Note that often it's best to install
it into a virtual environment[venv], so, create one:
```
python -m venv dirname
```

Where dirname will be a directory that will be created w/ the virtual
environment. Start the virtual environment (if you're using bash/zsh,
other shells are documented in the link above):
```
source ./dirname/bin/activate
```

Install wsfwd:
```
pip install git+https://www.funkthat.com/gitea/jmg/wsfwd
```

Run the server (for example to forward to the local machine's sshd):
```
wsfwd serve 127.0.0.1:22
```

Multiple `ip:port` maybe specified on the command to allow forwarding
to multiple hosts.

This will start up a webserver. It uses hypercorn, so, all the
hypercorn command line arguments may be specified via the `--hypercorn`
argument. For example, to bind only to 127.0.0.1, you can run:
```
wsfwd serve --hypercorn '--bind 127.0.0.1' 127.0.0.1:22
```

Once that has been setup, the client can be used to connect to the
specified server:
```
wsfwd connect http://127.0.0.1/connect 127.0.0.1:22
```

This is most useful to be specified to ProxyCommand in your .ssh/config
file:
```
Host localhost
ProxyCommand /path/to/dirname/bin/wsfwd connect http://127.0.0.1:8000/connect 127.0.0.1:22
```

There are many ways to use it. Another option could also be to launch
it via inetd to transparently pass normal ssh sessions to remote hosts.


Protocol
--------
@@ -107,3 +164,5 @@ with other implementations.
FastAPI uses starlette: https://www.starlette.io/websockets/

Client: https://github.com/aaugustin/websockets

[venv]: https://docs.python.org/3/library/venv.html

+ 119
- 39
wsfwd/__init__.py View File

@@ -28,7 +28,13 @@ from unittest.mock import patch, Mock, AsyncMock

PIPE = object()

__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ]
__all__ = [
'create_conn_server',
'HandleConnectLimited',
'WSFWDServer',
'WSFWDClient',
'WFProcess',
]

def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
@@ -46,7 +52,8 @@ def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
return await asyncio.wait_for(fun(*args, **kwargs), timeout)
return await asyncio.wait_for(fun(*args, **kwargs),
timeout)

return wrapper

@@ -208,7 +215,8 @@ class WFStreamWriter:
return self._closed

async def _close_task(self):
await self._client.sendcmd(dict(cmd='chanclose', chan=self._stream))
await self._client.sendcmd(dict(cmd='chanclose',
chan=self._stream))

self._closed_event.set()
self._closed_event = None
@@ -342,7 +350,8 @@ class WSFWDCommon:
senddata = b''.join(data)

if senddata:
await self._writer(stream.to_bytes(1, 'big') + senddata)
await self._writer(stream.to_bytes(1, 'big') +
senddata)

async def __aenter__(self):
return self
@@ -458,7 +467,8 @@ class WSFWDClient(WSFWDCommon):
rsp = await self.sendcmd(dict(cmd='auth', auth=auth))

if 'error' in rsp:
raise RuntimeError('Got auth error: %s' % repr(rsp['error']))
raise RuntimeError('Got auth error: %s' %
repr(rsp['error']))

@staticmethod
async def _pushdata(writer, data):
@@ -474,11 +484,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()

self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)

self._proc = None

rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='connect', args=args,
stdin=1, stdout=2))

if 'error' in rsp:
raise RuntimeError(rsp['error'])
@@ -496,11 +508,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()

self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)

self._proc = WFProcess(self, self._stdin, self._stdout)

rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1,
stdout=2))

if 'error' in rsp:
raise RuntimeError(rsp['error'])
@@ -544,7 +558,8 @@ async def run_connect(url, ipport):
#_debprint('in', repr(stdin), repr(wtr))
#_debprint('out', repr(rdr), repr(stdout))

await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout))
await asyncio.gather(fwd_data(stdin, wtr),
fwd_data(rdr, stdout))

sys.exit(0)

@@ -559,6 +574,14 @@ async def run_connect(url, ipport):
sys.exit(1)

class HandleConnectLimited(WSFWDServer):
'''
A server that allows connection to the specified list of ip:port's
specified in the arguemnt `limited`.

It is recommend to use `create_conn_server` directly, instead of
instantiating it yourself.
'''

def __init__(self, *args, limited, **kwargs):
super().__init__(*args, **kwargs)

@@ -629,7 +652,19 @@ class HandleConnectLimited(WSFWDServer):
async def get_finish_handler(self):
return await self._finish_handler.wait()

def create_conn_server(app, args):
def create_conn_server(app, *args):
'''Add a route to app for /connect, that will allow connections
to the provided ip:port.

Example usage:
```
from fastapi import FastAPI
from wsfwd import create_conn_server
app = FastAPI()
create_conn_server(app, '127.0.0.1:22')
```
'''

@app.websocket('/connect')
async def connect_ws(webSocket: WebSocket):
await webSocket.accept()
@@ -653,12 +688,16 @@ def real_main():
help='connect to a socket at the specified URL')
parser_connect.add_argument('url', type=str,
help='the URL to issue the connect command to')
parser_connect.add_argument('ipport', type=str, help='<ip address>:<port>')
parser_connect.add_argument('ipport', type=str,
help='<ip address>:<port>')

parser_serve = subparsers.add_parser('serve',
help='Serve connection requests to the provided <ip>:<port> tuples.')
parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+', help='<ip address>[:<port>]')
parser_serve.add_argument('--hypercorn', type=str, action='append',
default=[],
help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+',
help='<ip address>[:<port>]')

args = parser.parse_args()
#print(repr(args), file=sys.__stderr__)
@@ -667,13 +706,14 @@ def real_main():
return run_connect(args.url, args.ipport)
elif args.subparser_name == 'serve':
# make hypercorn args
hypercornargs = sum((shlex.split(x) for x in args.hypercorn), [])
hypercornargs = sum((shlex.split(x) for x in args.hypercorn),
[])

# make app
global app

app = FastAPI()
create_conn_server(app, args.ipport)
create_conn_server(app, *args.ipport)

import wsfwd
#_debprint(repr(wsfwd.app))
@@ -735,7 +775,7 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
connarg = '127.0.0.1:12345'

app = FastAPI()
create_conn_server(app, [ connarg ])
create_conn_server(app, connarg)

self.serv_task = asyncio.create_task(serve(app, self.config,
shutdown_trigger=self.shutdown_event.wait))
@@ -754,7 +794,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):

async with websockets.unix_connect(self.socketpath,
'ws://foo/connect') as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
WSFWDClient(wrapper(websocket.recv),
wrapper(websocket.send)) as client:
mstdout = AsyncMock()

echodata = b'somedata'
@@ -778,7 +819,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
await writer.drain()

# that we get our data
self.assertEqual(await reader.read(len(echodata)), echodata)
self.assertEqual(await reader.read(len(echodata)),
echodata)

# and that there is no more
self.assertEqual(await reader.read(len(echodata)), b'')
@@ -896,26 +938,50 @@ class TestMain(unittest.IsolatedAsyncioTestCase):
return ret, stdoutvalue

def test_mainserver(self):
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__,
dict(args=[ ], stderr=fp)), \
self.assertRaises(SystemExit) as context:
asyncio.run(real_main())

self.assertEqual(context.exception.code, 2)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=[ 'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ])
hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345',
'another', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345', 'another')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

@timeout(2)
@@ -972,7 +1038,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
return WSFWDClient(self.toclient.get, self.toserver.put)

def runFakeServer(self, func):
return asyncio.create_task(func(self.toserver.get, self.toclient.put))
return asyncio.create_task(func(self.toserver.get,
self.toclient.put))

@timeout(2)
@patch('wsfwd.WSFWDCommon.shutdown')
@@ -991,7 +1058,8 @@ class Test(unittest.IsolatedAsyncioTestCase):

msg = self.decode_cmdmsg(cmd)
self.assertEqual(msg['cmd'], 'auth')
await writer(self._encodecmd(dict(resp='auth', id=msg['id'], error='Invalid auth')))
await writer(self._encodecmd(dict(resp='auth',
id=msg['id'], error='Invalid auth')))

serv_task = self.runFakeServer(fake_server)

@@ -1100,7 +1168,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
ccmsg = await reader()
msg = self.decode_cmdmsg(ccmsg)

self.assertEqual(msg, dict(resp='chanclose', id=1, chan=2))
self.assertEqual(msg, dict(resp='chanclose', id=1,
chan=2))

# return the exit code
await writer(self._encodecmd(dict(cmd='exit', id=2,
@@ -1120,9 +1189,11 @@ class Test(unittest.IsolatedAsyncioTestCase):

# the when exec fails w/ error it is caught
with self.assertRaises(RuntimeError):
await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)

proc = await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
proc = await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)

writer, reader = proc.stdin, proc.stdout

@@ -1164,7 +1235,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
self.assertFalse(procwaittask.done())

# that the wait_closed
waitclosedtask = asyncio.create_task(writer.wait_closed())
waitclosedtask = asyncio.create_task(
writer.wait_closed())

# and when allowed to run
await asyncio.sleep(0)
@@ -1218,19 +1290,23 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:])
r['resp'] = r['cmd']
del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))

t = asyncio.create_task(task())
t2 = asyncio.create_task(task())

async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))

self.assertEqual(r, dict(resp='somecmd', id=1, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=1,
foo='bar'))

r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))

self.assertEqual(r, dict(resp='somecmd', id=2, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=2,
foo='bar'))

await t
await t2
@@ -1247,20 +1323,24 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:])
r['resp'] = r['cmd']
del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))

t = asyncio.create_task(task())
t2 = asyncio.create_task(task())

async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
try:
r = asyncio.create_task(cmn.sendcmd(dict(cmd='somecmd', foo='bar')))
r = asyncio.create_task(cmn.sendcmd(dict(
cmd='somecmd', foo='bar')))

loop = asyncio.get_running_loop()
loop.call_later(.1, ev.set)

# make sure that we can schedule a second one
await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))
await cmn.sendcmd(dict(cmd='somecmd',
foo='bar'))
finally:
await r



Loading…
Cancel
Save