@@ -1,13 +1,28 @@
import aioconsole.stream
import argparse
import asyncio
import contextlib
import functools
import io
import itertools
import json
import os
import shlex
import shutil
import sys
import tempfile
import unittest
import websockets
import traceback
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.websockets import WebSocket
from hypercorn.__main__ import main as hypercorn_main
from hypercorn.config import Config
from hypercorn.asyncio import serve
from io import TextIOWrapper
from typing import Dict, Any
from unittest.mock import patch, Mock, AsyncMock
@@ -15,6 +30,18 @@ PIPE = object()
__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ]
def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
st = traceback.extract_stack(limit=2)[0]
sep = ''
if args:
sep = ':'
print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
*args, file=sys.stderr)
sys.stderr.flush()
def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
@@ -25,13 +52,25 @@ def timeout(timeout):
return timeout_wrapper
def _tbprinter(fun): #pragma: no cover
def _a tbprinter(fun): #pragma: no cover
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
try:
return await fun(*args, **kwargs)
except Exception:
print('in tbprinter:', repr(fun))
_debprint('in atbprinter:', repr(fun))
traceback.print_exc()
raise
return wrapper
def _tbprinter(fun): #pragma: no cover
@functools.wraps(fun)
def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception:
_debprint('in tbprinter:', repr(fun))
traceback.print_exc()
raise
@@ -337,6 +376,8 @@ class WSFWDCommon:
self._waitingresp[cmd['id']] = fut
#_debprint('sendcmd:', repr(self), repr(cmd))
await self._sendcmd(cmd)
return await fut
@@ -347,13 +388,14 @@ class WSFWDCommon:
can be used for either commands or response to commands.
'''
#_debprint('_sendcmd:', repr(self), repr(self._writer), repr(cmd))
await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8'))
async def _proccmds(self, msg):
'''Interal routine for dispatching commands.'''
msg = json.loads(msg)
#print('_proccmds:', repr(self), repr(msg))
#_deb print('_proccmds:', repr(self), repr(msg))
if 'cmd' in msg:
handler = getattr(self, 'handle_%s' % msg['cmd'])
msgid = msg['id']
@@ -367,11 +409,13 @@ class WSFWDCommon:
error=e.args[0])
# assert resp['id'] == msgid
#print('response:', repr(resp))
#_deb print('response:', repr(resp))
await self._sendcmd(resp)
#_debprint('sent:', repr(resp))
elif 'resp' in msg:
fut = self._waitingresp.pop(msg['id'] )
#_debprint('got resp:', repr(self), repr(msg) )
fut = self._waitingresp.pop(msg['id'])
fut.set_result(msg)
class WSFWDServer(WSFWDCommon):
@@ -420,6 +464,27 @@ class WSFWDClient(WSFWDCommon):
async def _pushdata(writer, data):
writer.feed_data(data)
async def connect(self, args):
'''
Returns a StreamReader, StreamWriter pair.
'''
# get the stdin/stdout setup
self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()
self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)
self._proc = None
rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2))
if 'error' in rsp:
raise RuntimeError(rsp['error'])
return self._stdout, self._stdin
async def exec(self, args, stdin=PIPE, stdout=PIPE):
'''
Returns a WFProcess instance. WFProcess is very similar
@@ -442,6 +507,453 @@ class WSFWDClient(WSFWDCommon):
return self._proc
def convert_to_ws(url):
if url.startswith('http://'):
url = url.replace('http', 'ws', 1)
elif url.startswith('https://'):
url = url.replace('https', 'wss', 1)
return url
# how to do stdin/stdout via async:
# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130
async def fwd_data(reader, writer):
while True:
data = await reader.read(16384)
if data == b'':
#_debprint('fwd_data eof', repr(reader), repr(writer))
writer.close()
await writer.wait_closed()
#_debprint('fwd_data done', repr(reader), repr(writer))
return
#_debprint('fwd_data data', repr(reader), repr(writer), len(data))
writer.write(data)
await writer.drain()
async def run_connect(url, ipport):
url = convert_to_ws(url)
stdin, stdout = await aioconsole.stream.get_standard_streams()
async with websockets.connect(url) as ws, WSFWDClient(ws.recv,
ws.send) as client:
try:
rdr, wtr = await client.connect(args=ipport)
#_debprint('in', repr(stdin), repr(wtr))
#_debprint('out', repr(rdr), repr(stdout))
await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout))
sys.exit(0)
except RuntimeError as e:
print('failed to exec: %s' % e.args)
# not a fan of this, shouldn't be needed, but
# how tests are run w/ runAsyncMain, it is
# required here.
sys.stdout.flush()
sys.exit(1)
class HandleConnectLimited(WSFWDServer):
def __init__(self, *args, limited, **kwargs):
super().__init__(*args, **kwargs)
self._limited = limited
self._did_connect = False
self._finish_handler = asyncio.Event()
async def shutdown(self):
pass
def _validatearg(self, ipport):
if ipport in self._limited:
return
raise RuntimeError('not allowed ipport: %s' % repr(ipport))
async def process_writer(self, data):
wrt = self.__writer
wrt.write(data)
await wrt.drain()
async def process_reader(self):
rdr = self.__reader
stream = self._reader_stream
try:
while True:
data = await rdr.read(16384)
if not data:
break
self.sendstream(stream, data)
await self.drain(stream)
finally:
await self.sendcmd(dict(cmd='chanclose', chan=stream))
async def handle_chanclose(self, msg):
self.clear_stream_handler(self._writer_stream)
self.__writer.close()
await self.__writer.wait_closed()
self._writer_event.set()
async def handle_connect(self, msg):
if self._did_connect:
raise RuntimeError('already did connect')
ipport = msg['args']
self._validatearg(ipport)
host, port = ipport.split(':')
port = int(port)
self.__reader, self.__writer = await asyncio.open_connection(
host=host, port=port)
self._did_connect = True
self._writer_stream = msg['stdin']
self._reader_stream = msg['stdout']
# handle writer
self._writer_event = asyncio.Event()
self.add_stream_handler(msg['stdin'], self.process_writer)
# handle reader
self._reader_task = asyncio.create_task(self.process_reader())
async def get_finish_handler(self):
return await self._finish_handler.wait()
def create_conn_server(app, args):
@app.websocket('/connect')
async def connect_ws(webSocket: WebSocket):
await webSocket.accept()
try:
#_debprint('ccs:', repr(webSocket), repr(webSocket.receive_bytes), repr(webSocket.send_bytes))
async with HandleConnectLimited(webSocket.receive_bytes,
webSocket.send_bytes, limited=args) as server:
await server.get_finish_handler()
finally:
await webSocket.close()
@_tbprinter
def real_main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(title='subcommands',
dest='subparser_name',
description='valid subcommands', help='additional help')
parser_connect = subparsers.add_parser('connect',
help='connect to a socket at the specified URL')
parser_connect.add_argument('url', type=str,
help='the URL to issue the connect command to')
parser_connect.add_argument('ipport', type=str, help='<ip address>:<port>')
parser_serve = subparsers.add_parser('serve',
help='Serve connection requests to the provided <ip>:<port> tuples.')
parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+', help='<ip address>[:<port>]')
args = parser.parse_args()
#print(repr(args), file=sys.__stderr__)
if args.subparser_name == 'connect':
return run_connect(args.url, args.ipport)
elif args.subparser_name == 'serve':
# make hypercorn args
hypercornargs = sum((shlex.split(x) for x in args.hypercorn), [])
# make app
global app
app = FastAPI()
create_conn_server(app, args.ipport)
import wsfwd
#_debprint(repr(wsfwd.app))
hypercornargs.append('%s:app' % __name__)
hypercorn_main(sys_args=hypercornargs)
return
parser.print_usage()
async def fun():
sys.exit(5)
return fun()
def wrap_connect(mockobj, readdata=b''):
reader = Mock()
reader.read = AsyncMock()
reader.read.side_effect = [ readdata, b'' ]
writer = Mock()
writer.drain = AsyncMock()
writer.wait_closed = AsyncMock()
mockobj.return_value = (reader, writer)
class TestCTW(unittest.TestCase):
def test_convert_to_ws(self):
self.assertEqual(convert_to_ws('http://foo/'), 'ws://foo/')
self.assertEqual(convert_to_ws('https://foo/'), 'wss://foo/')
class TestServer(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d
self.shutdown_event = asyncio.Event()
self.socketpath = os.path.join(self.basetempdir, 'wstest.sock')
config = Config()
config.graceful_timeout = .01
config.bind = [ 'unix:' + self.socketpath ]
config.loglevel = 'ERROR'
self.config = config
async def asyncTearDown(self):
self.app = None
self.shutdown_event.set()
await self.serv_task
shutil.rmtree(self.basetempdir)
self.basetempdir = None
@patch('asyncio.open_connection')
@timeout(2)
async def test_connect(self, oc):
connarg = '127.0.0.1:12345'
app = FastAPI()
create_conn_server(app, [ connarg ])
self.serv_task = asyncio.create_task(serve(app, self.config,
shutdown_trigger=self.shutdown_event.wait))
# get the unix domain socket connected
# need a startup_trigger
await asyncio.sleep(.01)
def wrapper(corofun):
async def foo(*args, **kwargs):
r = await corofun(*args, **kwargs)
#print('foo:', repr(corofun), repr((args, kwargs)), repr(r))
return r
return foo
async with websockets.unix_connect(self.socketpath,
'ws://foo/connect') as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
mstdout = AsyncMock()
echodata = b'somedata'
wrap_connect(oc, readdata=echodata)
# that bad args are rejected
with self.assertRaises(RuntimeError):
await client.connect('192.168.0.1:12345')
with self.assertRaises(RuntimeError):
await client.connect('127.0.0.1:348')
client.add_stream_handler(2, mstdout)
reader, writer = await client.connect(connarg)
# that it cannot be connected'd a second time
with self.assertRaises(RuntimeError):
await client.connect(connarg)
writer.write(echodata)
await writer.drain()
# that we get our data
self.assertEqual(await reader.read(len(echodata)), echodata)
# and that there is no more
self.assertEqual(await reader.read(len(echodata)), b'')
# and we are truly at EOF
self.assertTrue(reader.at_eof())
writer.close()
await writer.wait_closed()
# make sure that it was called w/ the correct arugments
oc.assert_called_with(host='127.0.0.1', port=12345)
# spin things, not sure best way to handle this
await asyncio.sleep(.01)
# make sure that the writer was closed
oc.return_value[1].close.assert_called_with()
class TestMain(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.toclient = asyncio.Queue()
self.toserver = asyncio.Queue()
async def asyncTearDown(self):
self.assertTrue(self.toclient.empty())
self.assertTrue(self.toserver.empty())
def setup_websockets_mock(self, webcon):
conobj = Mock()
webcon().__aenter__.return_value = conobj
conobj.send = self.toserver.put
conobj.recv = self.toclient.get
webcon.reset_mock()
@contextlib.contextmanager
def make_pipe(self):
r, w = os.pipe()
with os.fdopen(r, 'rb', buffering=65536) as readfl, \
os.fdopen(w, 'wb', buffering=65536) as writefl:
yield readfl, writefl
# too lazy to make this async since async file-like objects
# aren't standard yet in Python
def copytask(self, reader, writer, doclose=True):
while True:
data = reader.read(16384)
#_debprint('ct', repr(reader), repr(writer), doclose, len(data))
if not data:
if doclose:
writer.close()
#_debprint('ct done', repr(reader), repr(writer), doclose)
return
writer.write(data)
@_atbprinter
async def runAsyncMain(self, fun=real_main, stdin=''):
# make stdin bytes
if isinstance(stdin, str):
stdin = stdin.encode()
# make stdout
stdout = io.BytesIO()
# Data path:
# stdin -> stdin_task -> stdinwriter -> pipe ->
# stdinreader -> sys.stdin -> real_main -> sys.stdout ->
# stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
#
# How things get closed:
# stdin is already "closed", stdin_task will close
# stdinwriter when eof encountered (doclose).
# create the pipes needed
with self.make_pipe() as (stdinreader, stdinwriter), \
self.make_pipe() as (stdoutreader, stdoutwriter):
# setup the threads to move data
loop = asyncio.get_running_loop()
stdin_task = loop.run_in_executor(None, self.copytask,
io.BytesIO(stdin), stdinwriter)
# do not close stdout, otherwise we cannot call
# getvalue on BytesIO object
stdout_task = loop.run_in_executor(None, self.copytask,
stdoutreader, stdout, False)
#_debprint('ram', repr(stdout_task), repr(stdoutreader))
# insert the pipes
with patch.dict(sys.__dict__,
dict(stdin=TextIOWrapper(stdinreader),
stdout=TextIOWrapper(stdoutwriter))):
try:
# run the function
await fun()
#await asyncio.wait_for(fun(), 1)
ret = 0 #pragma: no cover
except SystemExit as e:
#_debprint('exit', repr(e.code))
ret = e.code
# No one to write anything anymore
# close so stdout_task will end
stdoutwriter.close()
# make sure all the data has been copied
await asyncio.gather(stdin_task, stdout_task)
stdoutvalue = stdout.getvalue()
return ret, stdoutvalue
def test_mainserver(self):
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \
self.assertRaises(SystemExit) as context:
asyncio.run(real_main())
self.assertEqual(context.exception.code, 2)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
real_main()
hcm.assert_called_with(sys_args=[ 'wsfwd:app' ])
self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
real_main()
hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ])
self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)
@timeout(2)
async def test_socket(self):
class TestServer(WSFWDCommon):
async def echo_handler(self, stream, msg):
self.sendstream(stream, msg)
await self.drain(stream)
async def handle_chanclose(self, msg):
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='chanclose',
chan=self._stdout_stream))))
async def handle_connect(self, msg):
self._stdout_stream = msg['stdout']
assert msg['args'] == '127.0.0.1:38493'
self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler,
msg['stdout']))
server = TestServer(self.toserver.get, self.toclient.put)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'connect',
'https://example.com/connectws', '127.0.0.1:38493', ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)
inpdata = bytes(range(0, 255))
ret, stdout = await self.runAsyncMain(stdin=inpdata)
await server.__aexit__(None, None, None)
self.assertEqual(stdout, inpdata)
self.assertEqual(ret, 0)
class Test(unittest.IsolatedAsyncioTestCase):
@staticmethod
def _encodecmd(payload):
@@ -474,7 +986,6 @@ class Test(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_authfail(self):
@_tbprinter
async def fake_server(reader, writer):
cmd = await reader()
@@ -508,7 +1019,6 @@ class Test(unittest.IsolatedAsyncioTestCase):
writerdata2 = b'oiwuersldkj'
readerdata1 = b'weoiusdofiuwe'
@_tbprinter
async def fake_server(reader, writer):
# get the auth command
auth = await reader()