| @@ -55,7 +55,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| async def negotiate(self, reader, writer): | async def negotiate(self, reader, writer): | ||||
| try: | try: | ||||
| req = self.socks_request(c.SOCKS_CMD_CONNECT) | |||||
| req = self.socks_request(self.cmd) | |||||
| self._proxy_peername, self._proxy_sockname = await req | self._proxy_peername, self._proxy_sockname = await req | ||||
| except SocksError as exc: | except SocksError as exc: | ||||
| exc = SocksError('Can not connect to %s:%s. %s' % | exc = SocksError('Can not connect to %s:%s. %s' % | ||||
| @@ -4,17 +4,48 @@ import socket | |||||
| from aiohttp.test_utils import unused_port | from aiohttp.test_utils import unused_port | ||||
| def _asyncsockpair(): | |||||
| '''Create a pair of sockets that are bound to each other. | |||||
| The function will return a tuple of two coroutine's, that | |||||
| each, when await'ed upon, will return the reader/writer pair.''' | |||||
| socka, sockb = socket.socketpair() | |||||
| return asyncio.open_connection(sock=socka), \ | |||||
| asyncio.open_connection(sock=sockb) | |||||
| async def _getreaderwriter(): | |||||
| '''Return a reader/writer pair. Any data written | |||||
| to the reader can be read from the writer side. | |||||
| returns (reader, writer).''' | |||||
| socka, sockb = _asyncsockpair() | |||||
| ardr, awrr = await socka | |||||
| brdr, bwrr = await sockb | |||||
| # don't close, as it also closes the reader as well | |||||
| awrr.write_eof() | |||||
| return ardr, bwrr | |||||
| class FakeSocksSrv: | class FakeSocksSrv: | ||||
| def __init__(self, loop, write_buff): | def __init__(self, loop, write_buff): | ||||
| self._loop = loop | self._loop = loop | ||||
| self._write_buff = write_buff | self._write_buff = write_buff | ||||
| self._transports = [] | self._transports = [] | ||||
| self._srv = None | self._srv = None | ||||
| self._pipes = None | |||||
| self.port = unused_port() | self.port = unused_port() | ||||
| def get_reader(self): | |||||
| return self._pipes[0] | |||||
| async def __aenter__(self): | async def __aenter__(self): | ||||
| transports = self._transports | transports = self._transports | ||||
| write_buff = self._write_buff | write_buff = self._write_buff | ||||
| pipes = await _getreaderwriter() | |||||
| self._pipes = pipes | |||||
| class SocksPrimitiveProtocol(asyncio.Protocol): | class SocksPrimitiveProtocol(asyncio.Protocol): | ||||
| _transport = None | _transport = None | ||||
| @@ -24,6 +55,7 @@ class FakeSocksSrv: | |||||
| transports.append(transport) | transports.append(transport) | ||||
| def data_received(self, data): | def data_received(self, data): | ||||
| pipes[1].write(data) | |||||
| self._transport.write(write_buff) | self._transport.write(write_buff) | ||||
| def factory(): | def factory(): | ||||
| @@ -39,7 +71,14 @@ class FakeSocksSrv: | |||||
| tr.close() | tr.close() | ||||
| self._srv.close() | self._srv.close() | ||||
| self._pipes[1].close() | |||||
| await self._srv.wait_closed() | await self._srv.wait_closed() | ||||
| try: | |||||
| await self._pipes[1].wait_closed() | |||||
| except Exception: | |||||
| pass | |||||
| class FakeSocks4Srv: | class FakeSocks4Srv: | ||||
| @@ -11,6 +11,7 @@ from aiohttp.test_utils import make_mocked_coro | |||||
| from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv | from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv | ||||
| from aiosocks.connector import ProxyConnector, ProxyClientRequest | from aiosocks.connector import ProxyConnector, ProxyClientRequest | ||||
| from aiosocks.errors import SocksConnectionError | from aiosocks.errors import SocksConnectionError | ||||
| from aiosocks import constants as c | |||||
| from async_timeout import timeout | from async_timeout import timeout | ||||
| from unittest import mock | from unittest import mock | ||||
| @@ -117,6 +118,7 @@ async def test_socks5_datagram_success_anonymous(): | |||||
| portnum = 53 | portnum = 53 | ||||
| dst = (dname, portnum) | dst = (dname, portnum) | ||||
| # Fake SOCKS server UDP relay | |||||
| class FakeDGramTransport(asyncio.DatagramTransport): | class FakeDGramTransport(asyncio.DatagramTransport): | ||||
| def sendto(self, data, addr=None): | def sendto(self, data, addr=None): | ||||
| # Verify correct packet was receieved | # Verify correct packet was receieved | ||||
| @@ -142,6 +144,7 @@ async def test_socks5_datagram_success_anonymous(): | |||||
| sockservdgram = FakeDGramTransport() | sockservdgram = FakeDGramTransport() | ||||
| # Fake the creation of the UDP relay | |||||
| async def fake_cde(factory, remote_addr): | async def fake_cde(factory, remote_addr): | ||||
| assert remote_addr == ('1.1.1.1', 1111) | assert remote_addr == ('1.1.1.1', 1111) | ||||
| @@ -151,10 +154,15 @@ async def test_socks5_datagram_success_anonymous(): | |||||
| return sockservdgram, proto | return sockservdgram, proto | ||||
| # Open the UDP connection | |||||
| with mock.patch.object(loop, 'create_datagram_endpoint', | with mock.patch.object(loop, 'create_datagram_endpoint', | ||||
| fake_cde) as m: | fake_cde) as m: | ||||
| dgram = await aiosocks.open_datagram(addr, None, dst, loop=loop) | dgram = await aiosocks.open_datagram(addr, None, dst, loop=loop) | ||||
| rdr = srv.get_reader() | |||||
| # make sure we negotiated the correct command | |||||
| assert (await rdr.readexactly(5))[4] == c.SOCKS_CMD_UDP_ASSOCIATE | |||||
| assert dgram.proxy_sockname == ('1.1.1.1', 1111) | assert dgram.proxy_sockname == ('1.1.1.1', 1111) | ||||
| dgram.send(b'some data') | dgram.send(b'some data') | ||||