Browse Source

add support for server and client to do forwarding to TCP streams...

This will be documented in a followup commit...
main
John-Mark Gurney 4 years ago
parent
commit
7f7f187bcc
3 changed files with 531 additions and 7 deletions
  1. +5
    -0
      setup.py
  2. +517
    -7
      wsfwd/__init__.py
  3. +9
    -0
      wsfwd/__main__.py

+ 5
- 0
setup.py View File

@@ -18,12 +18,17 @@ setup(
#download_url='',
long_description=open('README.md').read(),
install_requires=[
'aioconsole', # for aioconsole.stream only
'fastapi',
'hypercorn',
'websockets',
],
extras_require = {
'dev': [ 'coverage' ],
},
entry_points={
'console_scripts': [
'wsfwd = wsfwd.__main__:main',
]
}
)

+ 517
- 7
wsfwd/__init__.py View File

@@ -1,13 +1,28 @@
import aioconsole.stream
import argparse
import asyncio
import contextlib
import functools
import io
import itertools
import json
import os
import shlex
import shutil
import sys
import tempfile
import unittest
import websockets

import traceback

from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.websockets import WebSocket
from hypercorn.__main__ import main as hypercorn_main
from hypercorn.config import Config
from hypercorn.asyncio import serve
from io import TextIOWrapper
from typing import Dict, Any
from unittest.mock import patch, Mock, AsyncMock

@@ -15,6 +30,18 @@ PIPE = object()

__all__ = [ 'WSFWDServer', 'WSFWDClient', 'WFProcess', ]

def _debprint(*args): # pragma: no cover
import traceback, sys, os.path
st = traceback.extract_stack(limit=2)[0]

sep = ''
if args:
sep = ':'

print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
*args, file=sys.stderr)
sys.stderr.flush()

def timeout(timeout):
def timeout_wrapper(fun):
@functools.wraps(fun)
@@ -25,13 +52,25 @@ def timeout(timeout):

return timeout_wrapper

def _tbprinter(fun): #pragma: no cover
def _atbprinter(fun): #pragma: no cover
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
try:
return await fun(*args, **kwargs)
except Exception:
print('in tbprinter:', repr(fun))
_debprint('in atbprinter:', repr(fun))
traceback.print_exc()
raise

return wrapper

def _tbprinter(fun): #pragma: no cover
@functools.wraps(fun)
def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception:
_debprint('in tbprinter:', repr(fun))
traceback.print_exc()
raise

@@ -337,6 +376,8 @@ class WSFWDCommon:

self._waitingresp[cmd['id']] = fut

#_debprint('sendcmd:', repr(self), repr(cmd))

await self._sendcmd(cmd)

return await fut
@@ -347,13 +388,14 @@ class WSFWDCommon:
can be used for either commands or response to commands.
'''

#_debprint('_sendcmd:', repr(self), repr(self._writer), repr(cmd))
await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8'))

async def _proccmds(self, msg):
'''Interal routine for dispatching commands.'''

msg = json.loads(msg)
#print('_proccmds:', repr(self), repr(msg))
#_debprint('_proccmds:', repr(self), repr(msg))
if 'cmd' in msg:
handler = getattr(self, 'handle_%s' % msg['cmd'])
msgid = msg['id']
@@ -367,11 +409,13 @@ class WSFWDCommon:
error=e.args[0])

# assert resp['id'] == msgid
#print('response:', repr(resp))
#_debprint('response:', repr(resp))
await self._sendcmd(resp)
#_debprint('sent:', repr(resp))
elif 'resp' in msg:
fut = self._waitingresp.pop(msg['id'])
#_debprint('got resp:', repr(self), repr(msg))

fut = self._waitingresp.pop(msg['id'])
fut.set_result(msg)

class WSFWDServer(WSFWDCommon):
@@ -420,6 +464,27 @@ class WSFWDClient(WSFWDCommon):
async def _pushdata(writer, data):
writer.feed_data(data)

async def connect(self, args):
'''
Returns a StreamReader, StreamWriter pair.
'''

# get the stdin/stdout setup

self._stdin = WFStreamWriter(self, 1)
self._stdout = asyncio.StreamReader()

self._procmsgs[2] = functools.partial(self._pushdata, self._stdout)

self._proc = None

rsp = await self.sendcmd(dict(cmd='connect', args=args, stdin=1, stdout=2))

if 'error' in rsp:
raise RuntimeError(rsp['error'])

return self._stdout, self._stdin

async def exec(self, args, stdin=PIPE, stdout=PIPE):
'''
Returns a WFProcess instance. WFProcess is very similar
@@ -442,6 +507,453 @@ class WSFWDClient(WSFWDCommon):

return self._proc

def convert_to_ws(url):
if url.startswith('http://'):
url = url.replace('http', 'ws', 1)
elif url.startswith('https://'):
url = url.replace('https', 'wss', 1)

return url

# how to do stdin/stdout via async:
# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130
async def fwd_data(reader, writer):
while True:
data = await reader.read(16384)
if data == b'':
#_debprint('fwd_data eof', repr(reader), repr(writer))
writer.close()
await writer.wait_closed()
#_debprint('fwd_data done', repr(reader), repr(writer))
return

#_debprint('fwd_data data', repr(reader), repr(writer), len(data))
writer.write(data)

await writer.drain()

async def run_connect(url, ipport):
url = convert_to_ws(url)
stdin, stdout = await aioconsole.stream.get_standard_streams()

async with websockets.connect(url) as ws, WSFWDClient(ws.recv,
ws.send) as client:
try:
rdr, wtr = await client.connect(args=ipport)

#_debprint('in', repr(stdin), repr(wtr))
#_debprint('out', repr(rdr), repr(stdout))

await asyncio.gather(fwd_data(stdin, wtr), fwd_data(rdr, stdout))

sys.exit(0)

except RuntimeError as e:
print('failed to exec: %s' % e.args)

# not a fan of this, shouldn't be needed, but
# how tests are run w/ runAsyncMain, it is
# required here.
sys.stdout.flush()

sys.exit(1)

class HandleConnectLimited(WSFWDServer):
def __init__(self, *args, limited, **kwargs):
super().__init__(*args, **kwargs)

self._limited = limited
self._did_connect = False
self._finish_handler = asyncio.Event()

async def shutdown(self):
pass

def _validatearg(self, ipport):
if ipport in self._limited:
return

raise RuntimeError('not allowed ipport: %s' % repr(ipport))

async def process_writer(self, data):
wrt = self.__writer
wrt.write(data)
await wrt.drain()

async def process_reader(self):
rdr = self.__reader
stream = self._reader_stream

try:
while True:
data = await rdr.read(16384)
if not data:
break
self.sendstream(stream, data)
await self.drain(stream)
finally:
await self.sendcmd(dict(cmd='chanclose', chan=stream))

async def handle_chanclose(self, msg):
self.clear_stream_handler(self._writer_stream)
self.__writer.close()
await self.__writer.wait_closed()
self._writer_event.set()

async def handle_connect(self, msg):
if self._did_connect:
raise RuntimeError('already did connect')

ipport = msg['args']

self._validatearg(ipport)

host, port = ipport.split(':')
port = int(port)

self.__reader, self.__writer = await asyncio.open_connection(
host=host, port=port)

self._did_connect = True

self._writer_stream = msg['stdin']
self._reader_stream = msg['stdout']

# handle writer
self._writer_event = asyncio.Event()
self.add_stream_handler(msg['stdin'], self.process_writer)

# handle reader
self._reader_task = asyncio.create_task(self.process_reader())

async def get_finish_handler(self):
return await self._finish_handler.wait()

def create_conn_server(app, args):
@app.websocket('/connect')
async def connect_ws(webSocket: WebSocket):
await webSocket.accept()
try:
#_debprint('ccs:', repr(webSocket), repr(webSocket.receive_bytes), repr(webSocket.send_bytes))
async with HandleConnectLimited(webSocket.receive_bytes,
webSocket.send_bytes, limited=args) as server:
await server.get_finish_handler()
finally:
await webSocket.close()

@_tbprinter
def real_main():
parser = argparse.ArgumentParser()

subparsers = parser.add_subparsers(title='subcommands',
dest='subparser_name',
description='valid subcommands', help='additional help')

parser_connect = subparsers.add_parser('connect',
help='connect to a socket at the specified URL')
parser_connect.add_argument('url', type=str,
help='the URL to issue the connect command to')
parser_connect.add_argument('ipport', type=str, help='<ip address>:<port>')

parser_serve = subparsers.add_parser('serve',
help='Serve connection requests to the provided <ip>:<port> tuples.')
parser_serve.add_argument('--hypercorn', type=str, action='append', default=[], help='arguments to hypercorn, will be split per standard sh rules')
parser_serve.add_argument('ipport', type=str, nargs='+', help='<ip address>[:<port>]')

args = parser.parse_args()
#print(repr(args), file=sys.__stderr__)

if args.subparser_name == 'connect':
return run_connect(args.url, args.ipport)
elif args.subparser_name == 'serve':
# make hypercorn args
hypercornargs = sum((shlex.split(x) for x in args.hypercorn), [])

# make app
global app

app = FastAPI()
create_conn_server(app, args.ipport)

import wsfwd
#_debprint(repr(wsfwd.app))
hypercornargs.append('%s:app' % __name__)
hypercorn_main(sys_args=hypercornargs)

return

parser.print_usage()
async def fun():
sys.exit(5)

return fun()

def wrap_connect(mockobj, readdata=b''):
reader = Mock()
reader.read = AsyncMock()
reader.read.side_effect = [ readdata, b'' ]

writer = Mock()
writer.drain = AsyncMock()
writer.wait_closed = AsyncMock()

mockobj.return_value = (reader, writer)

class TestCTW(unittest.TestCase):
def test_convert_to_ws(self):
self.assertEqual(convert_to_ws('http://foo/'), 'ws://foo/')
self.assertEqual(convert_to_ws('https://foo/'), 'wss://foo/')

class TestServer(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d

self.shutdown_event = asyncio.Event()

self.socketpath = os.path.join(self.basetempdir, 'wstest.sock')

config = Config()
config.graceful_timeout = .01
config.bind = [ 'unix:' + self.socketpath ]
config.loglevel = 'ERROR'
self.config = config

async def asyncTearDown(self):
self.app = None

self.shutdown_event.set()

await self.serv_task

shutil.rmtree(self.basetempdir)
self.basetempdir = None

@patch('asyncio.open_connection')
@timeout(2)
async def test_connect(self, oc):
connarg = '127.0.0.1:12345'

app = FastAPI()
create_conn_server(app, [ connarg ])

self.serv_task = asyncio.create_task(serve(app, self.config,
shutdown_trigger=self.shutdown_event.wait))

# get the unix domain socket connected
# need a startup_trigger
await asyncio.sleep(.01)

def wrapper(corofun):
async def foo(*args, **kwargs):
r = await corofun(*args, **kwargs)
#print('foo:', repr(corofun), repr((args, kwargs)), repr(r))
return r

return foo

async with websockets.unix_connect(self.socketpath,
'ws://foo/connect') as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
mstdout = AsyncMock()

echodata = b'somedata'
wrap_connect(oc, readdata=echodata)

# that bad args are rejected
with self.assertRaises(RuntimeError):
await client.connect('192.168.0.1:12345')

with self.assertRaises(RuntimeError):
await client.connect('127.0.0.1:348')

client.add_stream_handler(2, mstdout)
reader, writer = await client.connect(connarg)

# that it cannot be connected'd a second time
with self.assertRaises(RuntimeError):
await client.connect(connarg)

writer.write(echodata)
await writer.drain()

# that we get our data
self.assertEqual(await reader.read(len(echodata)), echodata)

# and that there is no more
self.assertEqual(await reader.read(len(echodata)), b'')

# and we are truly at EOF
self.assertTrue(reader.at_eof())

writer.close()
await writer.wait_closed()

# make sure that it was called w/ the correct arugments
oc.assert_called_with(host='127.0.0.1', port=12345)

# spin things, not sure best way to handle this
await asyncio.sleep(.01)

# make sure that the writer was closed
oc.return_value[1].close.assert_called_with()

class TestMain(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.toclient = asyncio.Queue()
self.toserver = asyncio.Queue()

async def asyncTearDown(self):
self.assertTrue(self.toclient.empty())
self.assertTrue(self.toserver.empty())

def setup_websockets_mock(self, webcon):
conobj = Mock()

webcon().__aenter__.return_value = conobj

conobj.send = self.toserver.put
conobj.recv = self.toclient.get

webcon.reset_mock()

@contextlib.contextmanager
def make_pipe(self):
r, w = os.pipe()
with os.fdopen(r, 'rb', buffering=65536) as readfl, \
os.fdopen(w, 'wb', buffering=65536) as writefl:
yield readfl, writefl

# too lazy to make this async since async file-like objects
# aren't standard yet in Python
def copytask(self, reader, writer, doclose=True):
while True:
data = reader.read(16384)
#_debprint('ct', repr(reader), repr(writer), doclose, len(data))
if not data:
if doclose:
writer.close()
#_debprint('ct done', repr(reader), repr(writer), doclose)
return

writer.write(data)

@_atbprinter
async def runAsyncMain(self, fun=real_main, stdin=''):
# make stdin bytes
if isinstance(stdin, str):
stdin = stdin.encode()

# make stdout
stdout = io.BytesIO()

# Data path:
# stdin -> stdin_task -> stdinwriter -> pipe ->
# stdinreader -> sys.stdin -> real_main -> sys.stdout ->
# stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
#
# How things get closed:
# stdin is already "closed", stdin_task will close
# stdinwriter when eof encountered (doclose).

# create the pipes needed
with self.make_pipe() as (stdinreader, stdinwriter), \
self.make_pipe() as (stdoutreader, stdoutwriter):
# setup the threads to move data
loop = asyncio.get_running_loop()
stdin_task = loop.run_in_executor(None, self.copytask,
io.BytesIO(stdin), stdinwriter)

# do not close stdout, otherwise we cannot call
# getvalue on BytesIO object
stdout_task = loop.run_in_executor(None, self.copytask,
stdoutreader, stdout, False)

#_debprint('ram', repr(stdout_task), repr(stdoutreader))

# insert the pipes
with patch.dict(sys.__dict__,
dict(stdin=TextIOWrapper(stdinreader),
stdout=TextIOWrapper(stdoutwriter))):
try:
# run the function
await fun()
#await asyncio.wait_for(fun(), 1)
ret = 0 #pragma: no cover
except SystemExit as e:
#_debprint('exit', repr(e.code))
ret = e.code

# No one to write anything anymore
# close so stdout_task will end
stdoutwriter.close()

# make sure all the data has been copied
await asyncio.gather(stdin_task, stdout_task)

stdoutvalue = stdout.getvalue()

return ret, stdoutvalue

def test_mainserver(self):
with open('/dev/null', 'w') as fp, patch.dict(sys.__dict__, dict(args=[ ], stderr=fp)), \
self.assertRaises(SystemExit) as context:
asyncio.run(real_main())

self.assertEqual(context.exception.code, 2)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=[ 'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'serve', '--hypercorn', '-b 127.0.0.1', '127.0.0.1:12345', ])), patch('wsfwd.hypercorn_main') as hcm, patch('wsfwd.create_conn_server') as ccs:
real_main()

hcm.assert_called_with(sys_args=['-b', '127.0.0.1', 'wsfwd:app' ])

self.assertEqual(ccs.mock_calls[0][1][1], [ '127.0.0.1:12345' ])
self.assertIsInstance(ccs.mock_calls[0][1][0], FastAPI)

@timeout(2)
async def test_socket(self):
class TestServer(WSFWDCommon):
async def echo_handler(self, stream, msg):
self.sendstream(stream, msg)
await self.drain(stream)

async def handle_chanclose(self, msg):
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='chanclose',
chan=self._stdout_stream))))

async def handle_connect(self, msg):
self._stdout_stream = msg['stdout']
assert msg['args'] == '127.0.0.1:38493'
self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler,
msg['stdout']))

server = TestServer(self.toserver.get, self.toclient.put)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'connect',
'https://example.com/connectws', '127.0.0.1:38493', ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)

inpdata = bytes(range(0, 255))

ret, stdout = await self.runAsyncMain(stdin=inpdata)

await server.__aexit__(None, None, None)

self.assertEqual(stdout, inpdata)

self.assertEqual(ret, 0)

class Test(unittest.IsolatedAsyncioTestCase):
@staticmethod
def _encodecmd(payload):
@@ -474,7 +986,6 @@ class Test(unittest.IsolatedAsyncioTestCase):

@timeout(2)
async def test_authfail(self):
@_tbprinter
async def fake_server(reader, writer):
cmd = await reader()

@@ -508,7 +1019,6 @@ class Test(unittest.IsolatedAsyncioTestCase):
writerdata2 = b'oiwuersldkj'
readerdata1 = b'weoiusdofiuwe'

@_tbprinter
async def fake_server(reader, writer):
# get the auth command
auth = await reader()


+ 9
- 0
wsfwd/__main__.py View File

@@ -0,0 +1,9 @@
import asyncio

from wsfwd import real_main

def main():
asyncio.run(real_main())

if __name__ == '__main__': #pragma: no cover
main()

Loading…
Cancel
Save