diff --git a/wsfwd/__init__.py b/wsfwd/__init__.py index 50d1dd3..4244163 100644 --- a/wsfwd/__init__.py +++ b/wsfwd/__init__.py @@ -58,7 +58,7 @@ class WFProcess: self._stdin = stdin self._stdout = stdout self._returncode = None - self._retcode_waiters = [] + self._retcode_event = asyncio.Event() def set_returncode(self, retcode): ''' @@ -74,14 +74,9 @@ class WFProcess: self._returncode = retcode - waiters = self._retcode_waiters - - self._retcode_waiters = None - - for waiter in waiters: - if not waiter.cancelled(): - waiter.set_result(retcode) + self._retcode_event.set() + self._retcode_event = None @property def returncode(self): @@ -115,14 +110,8 @@ class WFProcess: if self.returncode is not None: return self.returncode - fut = asyncio.Future() - self._retcode_waiters.append(fut) - - try: - return await fut - except asyncio.CancelledError: - if self._retcode_waiters is not None: - self._retcode_waiters.remove(fut) + await self._retcode_event.wait() + return self.returncode class WFStreamWriter: '''This emulates asyncio.StreamWriter. For more info, see: @@ -136,6 +125,7 @@ class WFStreamWriter: self._client = client self._stream = stream self._closed = False + self._closed_event = asyncio.Event() def write(self, data): ''' @@ -178,6 +168,9 @@ class WFStreamWriter: self._closed = True + self._closed_event.set() + self._closed_event = None + self._closetask = asyncio.create_task(self._client._sendcmd(dict(cmd='chanclose', chan=self._stream))) async def wait_closed(self): @@ -185,7 +178,10 @@ class WFStreamWriter: https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.wait_closed ''' - pass + if self._closed: + return + + await self._closed_event.wait() class WSFWDCommon: def __init__(self, reader, writer): @@ -486,9 +482,21 @@ class Test(unittest.IsolatedAsyncioTestCase): # that the proc wait procwaittask = asyncio.create_task(proc.wait()) + # and when allowed to run + await asyncio.sleep(0) + # doesn't complete immediately self.assertFalse(procwaittask.done()) + # that the wait_closed + waitclosedtask = asyncio.create_task(writer.wait_closed()) + + # and when allowed to run + await asyncio.sleep(0) + + # doesn't complete immediately + self.assertFalse(waitclosedtask.done()) + writer.close() # that it fails on more data @@ -501,7 +509,8 @@ class Test(unittest.IsolatedAsyncioTestCase): self.assertTrue(writer.is_closing()) - await writer.wait_closed() + # but wait_closed does finally finish + await waitclosedtask self.assertEqual(await reader.read(), b'')