@@ -28,7 +28,13 @@ from unittest.mock import patch, Mock, AsyncMock
PIPE = object()
__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ]
__all__ = [
'create_conn_server',
'HandleConnectLimited',
'WSFWDServer',
'WSFWDClient',
'WFProcess',
]
def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
@@ -46,7 +52,8 @@ def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
return await asyncio.wait_for(fun(*args, **kwargs), timeout)
return await asyncio.wait_for(fun(*args, **kwargs),
timeout)
return wrapper
@@ -208,7 +215,8 @@ class WFStreamWriter:
return self._closed
async def _close_task(self):
await self._client.sendcmd(dict(cmd='chanclose', chan=self._stream))
await self._client.sendcmd(dict(cmd='chanclose',
chan=self._stream))
self._closed_event.set()
self._closed_event = None
@@ -342,7 +350,8 @@ class WSFWDCommon:
senddata = b''.join(data)
if senddata:
await self._writer(stream.to_bytes(1, 'big') + senddata)
await self._writer(stream.to_bytes(1, 'big') +
senddata)
async def __aenter__(self):
return self
@@ -458,7 +467,8 @@ class WSFWDClient(WSFWDCommon):
rsp = await self.sendcmd(dict(cmd='auth', auth=auth))
if 'error' in rsp:
raise RuntimeError('Got auth error: %s' % repr(rsp['error']))
raise RuntimeError('Got auth error: %s' %
repr(rsp['error']))
@staticmethod
async def _pushdata(writer, data):
@@ -474,11 +484,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()
self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)
self._proc = None
rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='connect', args=args,
stdin=1, stdout=2))
if 'error' in rsp:
raise RuntimeError(rsp['error'])
@@ -496,11 +508,13 @@ class WSFWDClient(WSFWDCommon):
self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()
self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._procmsgs[2] = functools.partial(self._pushdata,
self._stdout)
self._proc = WFProcess(self, self._stdin, self._stdout)
rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1, stdout=2))
rsp = await self.sendcmd(dict(cmd='exec', args=args, stdin=1,
stdout=2))
if 'error' in rsp:
raise RuntimeError(rsp['error'])
@@ -544,7 +558,8 @@ async def run_connect(url, ipport):
#_debprint('in', repr(stdin), repr(wtr))
#_debprint('out', repr(rdr), repr(stdout))
await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout))
await asyncio.gather(fwd_data(stdin, wtr),
fwd_data(rdr, stdout))
sys.exit(0)
@@ -559,6 +574,14 @@ async def run_connect(url, ipport):
sys.exit(1)
class HandleConnectLimited(WSFWDServer):
'''
A server that allows connection to the specified list of ip:port's
specified in the arguemnt `limited`.
It is recommend to use `create_conn_server` directly, instead of
instantiating it yourself.
'''
def __init__(self, *args, limited, **kwargs):
super().__init__(*args, **kwargs)
@@ -629,7 +652,19 @@ class HandleConnectLimited(WSFWDServer):
async def get_finish_handler(self):
return await self._finish_handler.wait()
def create_conn_server(app, args):
def create_conn_server(app, *args):
'''Add a route to app for /connect, that will allow connections
to the provided ip:port.
Example usage:
```
from fastapi import FastAPI
from wsfwd import create_conn_server
app = FastAPI()
create_conn_server(app, '127.0.0.1:22')
```
'''
@app.websocket('/connect')
async def connect_ws(webSocket: WebSocket):
await webSocket.accept()
@@ -653,12 +688,16 @@ def real_main():
help='connect to a socket at the specified URL')
parser_connect.add_argument('url', type=str,
help='the URL to issue the connect command to')
parser_connect.add_argument('ipport', type=str, help='<ip address>:<port>')
parser_connect.add_argument('ipport', type=str,
help='<ip address>:<port>')
parser_serve = subparsers.add_parser('serve',
help='Serve connection requests to the provided <ip>:<port> tuples.')
parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+', help='<ip address>[:<port>]')
parser_serve.add_argument('--hypercorn', type=str, action='append',
default=[],
help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+',
help='<ip address>[:<port>]')
args = parser.parse_args()
#print(repr(args), file=sys.__stderr__)
@@ -667,13 +706,14 @@ def real_main():
return run_connect(args.url, args.ipport)
elif args.subparser_name == 'serve':
# make hypercorn args
hypercornargs = sum((shlex.split(x) for x in args.hypercorn), [])
hypercornargs = sum((shlex.split(x) for x in args.hypercorn),
[])
# make app
global app
app = FastAPI()
create_conn_server(app, args.ipport)
create_conn_server(app, * args.ipport)
import wsfwd
#_debprint(repr(wsfwd.app))
@@ -735,7 +775,7 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
connarg = '127.0.0.1:12345'
app = FastAPI()
create_conn_server(app, [ connarg ] )
create_conn_server(app, connarg)
self.serv_task = asyncio.create_task(serve(app, self.config,
shutdown_trigger=self.shutdown_event.wait))
@@ -754,7 +794,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
async with websockets.unix_connect(self.socketpath,
'ws://foo/connect') as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
WSFWDClient(wrapper(websocket.recv),
wrapper(websocket.send)) as client:
mstdout = AsyncMock()
echodata = b'somedata'
@@ -778,7 +819,8 @@ class TestServer(unittest.IsolatedAsyncioTestCase):
await writer.drain()
# that we get our data
self.assertEqual(await reader.read(len(echodata)), echodata)
self.assertEqual(await reader.read(len(echodata)),
echodata)
# and that there is no more
self.assertEqual(await reader.read(len(echodata)), b'')
@@ -896,26 +938,50 @@ class TestMain(unittest.IsolatedAsyncioTestCase):
return ret, stdoutvalue
def test_mainserver(self):
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__,
dict(args=[ ], stderr=fp)), \
self.assertRaises(SystemExit) as context:
asyncio.run(real_main())
self.assertEqual(context.exception.code, 2)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()
hcm.assert_called_with(sys_args=[ 'wsfwd:app' ])
self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()
hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ])
hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])
self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve',
'--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345',
'another', ])), \
patch('wsfwd.hypercorn_main') as hcm, \
patch('wsfwd.create_conn_server') as ccs:
real_main()
hcm.assert_called_with(sys_args=['-b', '127.0.0.1',
'wsfwd:app' ])
self.assertEqual(ccs.mock_calls[0][1][1],
'127.0.0.1:12345', 'another')
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)
@timeout(2)
@@ -972,7 +1038,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
return WSFWDClient(self.toclient.get, self.toserver.put)
def runFakeServer(self, func):
return asyncio.create_task(func(self.toserver.get, self.toclient.put))
return asyncio.create_task(func(self.toserver.get,
self.toclient.put))
@timeout(2)
@patch('wsfwd.WSFWDCommon.shutdown')
@@ -991,7 +1058,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
msg = self.decode_cmdmsg(cmd)
self.assertEqual(msg['cmd'], 'auth')
await writer(self._encodecmd(dict(resp='auth', id=msg['id'], error='Invalid auth')))
await writer(self._encodecmd(dict(resp='auth',
id=msg['id'], error='Invalid auth')))
serv_task = self.runFakeServer(fake_server)
@@ -1100,7 +1168,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
ccmsg = await reader()
msg = self.decode_cmdmsg(ccmsg)
self.assertEqual(msg, dict(resp='chanclose', id=1, chan=2))
self.assertEqual(msg, dict(resp='chanclose', id=1,
chan=2))
# return the exit code
await writer(self._encodecmd(dict(cmd='exit', id=2,
@@ -1120,9 +1189,11 @@ class Test(unittest.IsolatedAsyncioTestCase):
# the when exec fails w/ error it is caught
with self.assertRaises(RuntimeError):
await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)
proc = await a.exec(args=execargs, stdin=PIPE, stdout=PIPE)
proc = await a.exec(args=execargs, stdin=PIPE,
stdout=PIPE)
writer, reader = proc.stdin, proc.stdout
@@ -1164,7 +1235,8 @@ class Test(unittest.IsolatedAsyncioTestCase):
self.assertFalse(procwaittask.done())
# that the wait_closed
waitclosedtask = asyncio.create_task(writer.wait_closed())
waitclosedtask = asyncio.create_task(
writer.wait_closed())
# and when allowed to run
await asyncio.sleep(0)
@@ -1218,19 +1290,23 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:])
r['resp'] = r['cmd']
del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))
t = asyncio.create_task(task())
t2 = asyncio.create_task(task())
async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=1, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=1,
foo='bar'))
r = await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=2, foo='bar'))
self.assertEqual(r, dict(resp='somecmd', id=2,
foo='bar'))
await t
await t2
@@ -1247,20 +1323,24 @@ class Test(unittest.IsolatedAsyncioTestCase):
r = json.loads(r[1:])
r['resp'] = r['cmd']
del r['cmd']
await self.toserver.put(b'\x00' + json.dumps(r).encode('utf-8'))
await self.toserver.put(b'\x00' +
json.dumps(r).encode('utf-8'))
t = asyncio.create_task(task())
t2 = asyncio.create_task(task())
async with WSFWDCommon(self.toserver.get, self.toclient.put) as cmn:
async with WSFWDCommon(self.toserver.get,
self.toclient.put) as cmn:
try:
r = asyncio.create_task(cmn.sendcmd(dict(cmd='somecmd', foo='bar')))
r = asyncio.create_task(cmn.sendcmd(dict(
cmd='somecmd', foo='bar')))
loop = asyncio.get_running_loop()
loop.call_later(.1, ev.set)
# make sure that we can schedule a second one
await cmn.sendcmd(dict(cmd='somecmd', foo='bar'))
await cmn.sendcmd(dict(cmd='somecmd',
foo='bar'))
finally:
await r