Browse Source

implement the server side of things...

main
John-Mark Gurney 5 years ago
parent
commit
75ceac2b83
1 changed files with 157 additions and 41 deletions
  1. +157
    -41
      wsfwd/__init__.py

+ 157
- 41
wsfwd/__init__.py View File

@@ -17,6 +17,19 @@ def timeout(timeout):

return timeout_wrapper

def _tbprinter(fun): #pragma: no cover
@functools.wraps(fun)
async def wrapper(*args, **kwargs):
try:
return await fun(*args, **kwargs)
except Exception:
import traceback
print('in tbprinter:', repr(fun))
traceback.print_exc()
raise

return wrapper

class TestTimeout(unittest.IsolatedAsyncioTestCase):
async def test_timeout(self):
@timeout(.001)
@@ -112,7 +125,74 @@ class WFStreamWriter:
async def wait_closed(self):
pass

class WSFWDClient:
class WSFWDCommon:
def __init__(self, reader, writer):
self._reader = reader
self._writer = writer
self._task = asyncio.create_task(self._process_msgs())

# this contains enqueued outbound data for draining
self._streams = dict()

# dispatch incoming messages
self._procmsgs = dict()

def add_stream_handler(self, stream, hndlr):
# XXX - make sure we don't overwrite an existing one
self._procmsgs[stream] = hndlr

async def _process_msgs(self):
while True:
msg = await self._reader()
#print('got:', repr(msg))
stream = msg[0]
await self._procmsgs[stream](msg[1:])

def sendstream(self, stream, *data):
if not all(isinstance(x, bytes) for x in data):
raise ValueError('write data must be bytes')

self._streams.setdefault(stream, []).extend(data)

async def drain(self, stream):
datalist = self._streams[stream]
data = datalist.copy()

# clear the data
datalist[:] = []

senddata = b''.join(data)

if senddata:
await self._writer(stream.to_bytes(1, 'big') + senddata)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
self._task.cancel()

async def _sendcmd(self, cmd):
await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8'))

class WSFWDServer(WSFWDCommon):
def __init__(self, reader, writer):
super().__init__(reader, writer)

self.add_stream_handler(0, self._proccmds)

async def _proccmds(self, msg):
msg = json.loads(msg)
handler = getattr(self, 'handle_%s' % msg['cmd'])
try:
resp = await handler(msg)
resp = dict(resp=msg['cmd'])
except RuntimeError as e:
resp = dict(resp=msg['cmd'], error=e.args[0])

await self._sendcmd(resp)

class WSFWDClient(WSFWDCommon):
def __init__(self, reader, writer):
'''This is the client for doing command execution over
a datagram protocol, such as WebSockets. The two
@@ -125,29 +205,17 @@ class WSFWDClient:
be sent to the server.
'''

self._reader = reader
self._writer = writer
self._task = asyncio.create_task(self._process_msgs())

# this contains enqueued outbound data for draining
self._streams = dict()
super().__init__(reader, writer)

self._cmdq = asyncio.Queue()

# dispatch incoming messages
self._procmsgs = dict()
self._procmsgs[0] = self._process_cmdmsg

# wait for msgs
self._fetchpend = dict()
self._fetchpend[0] = self._cmdq.get

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
self._task.cancel()

async def _process_msgs(self):
while True:
msg = await self._reader()
@@ -166,29 +234,8 @@ class WSFWDClient:

await self._cmdq.put(msg)

async def _sendcmd(self, cmd):
await self._writer(b'\x00' + json.dumps(cmd).encode('utf-8'))

def sendstream(self, stream, *data):
if not all(isinstance(x, bytes) for x in data):
raise ValueError('write data must be bytes')

self._streams.setdefault(stream, []).extend(data)

async def drain(self, stream):
datalist = self._streams[stream]
data = datalist.copy()

# clear the data
datalist[:] = []

senddata = b''.join(data)

if senddata:
await self._writer(stream.to_bytes(1, 'big') + senddata)

async def auth(self, auth):
await self._sendcmd(dict(auth=auth))
await self._sendcmd(dict(cmd='auth', auth=auth))

rsp = await self._fetchpend[0]()

@@ -247,10 +294,9 @@ class Test(unittest.IsolatedAsyncioTestCase):

serv_task = self.runFakeServer(fake_server)

a = self.runClient()

with self.assertRaises(RuntimeError):
await a.auth('randomtoken')
async with self.runClient() as a:
with self.assertRaises(RuntimeError):
await a.auth('randomtoken')

await serv_task

@@ -261,7 +307,7 @@ class Test(unittest.IsolatedAsyncioTestCase):
token = 'sdlfkjsoidfjl'

authdict = dict(bearer=token)
authmsg = { 'auth': authdict }
authmsg = { 'cmd': 'auth', 'auth': authdict }
execargs = [ 'sldkfj', 'oweijfls' ]
writerdata1 = b'asoijeflksdjf'
writerdata2 = b'oiwuersldkj'
@@ -417,3 +463,73 @@ class Test(unittest.IsolatedAsyncioTestCase):

# That the process task does get terminated
self.assertTrue(proctask.done())

@timeout(2)
async def test_client_server(self):
token = 'sdlfkjoijef'
authdict = dict(bearer=token)
badauthdict = dict(bearer=token + 'eofij')

class TestWSFDServer(WSFWDServer):
async def handle_auth(self, msg):
if msg['auth'] == authdict:
return

raise RuntimeError('Invalid auth')

async def echo_handler(self, stream, msg):
self.sendstream(stream, msg)
await self.drain(stream)

async def handle_exec(self, msg):
self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler,
msg['stdout']))

async def closehandler(ccmsg):
if ccmsg['chan'] == msg['stdin']:
await self.drain(msg['stdout'])
await self._sendcmd(dict(cmd='chanclose', chan=msg['stdout']))
await self._sendcmd(dict(cmd='exit', code=0))

# XXX - not happy w/ this api
self.handle_chanclose = closehandler

server = TestWSFDServer(self.toserver.get, self.toclient.put)

async with self.runClient() as a:
with self.assertRaises(RuntimeError) as cm:
await a.auth(badauthdict)

re = cm.exception
self.assertEqual(re.args[0], "Got auth error: 'Invalid auth'")

await a.auth(authdict)

proc = await a.exec(args=['echo'])

stdin, stdout = proc.stdin, proc.stdout

data = b'seoiadsujflaksdfj'
stdin.write(data)
await stdin.drain()

r = await stdout.read(len(data))

self.assertEqual(r, data)

data = data[::-1] # reverse data

stdin.write(data)
await stdin.drain()

r = await stdout.read(len(data))

self.assertEqual(r, data)

stdin.close()
await stdin.wait_closed()

self.assertEqual(await stdout.read(), b'')

self.assertEqual(await proc.wait(), 0)

Loading…
Cancel
Save