From e8a2b7cfbaf8e96131e069580bfb4415ca709b21 Mon Sep 17 00:00:00 2001 From: nibrag Date: Fri, 3 Jun 2016 16:45:22 +0300 Subject: [PATCH] Fix: some socks5 severs expect fully-formed command request --- aiosocks/protocols.py | 14 ++++++------- tests/test_protocols.py | 45 ++++++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index 4038aef..89cf757 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -284,9 +284,9 @@ class Socks5Protocol(BaseSocksProtocol): yield from self.authenticate() # build and send command - self.write_request([c.SOCKS_VER5, cmd, c.RSV]) - resolved = yield from self.write_address(self._dst_host, - self._dst_port) + dst_addr, resolved = yield from self.build_dst_address( + self._dst_host, self._dst_port) + self.write_request([c.SOCKS_VER5, cmd, c.RSV] + dst_addr) # read/process command response resp = yield from self.read_response(3) @@ -348,7 +348,7 @@ class Socks5Protocol(BaseSocksProtocol): ) @asyncio.coroutine - def write_address(self, host, port): + def build_dst_address(self, host, port): family_to_byte = {socket.AF_INET: c.SOCKS5_ATYP_IPv4, socket.AF_INET6: c.SOCKS5_ATYP_IPv6} port_bytes = struct.pack('>H', port) @@ -359,8 +359,7 @@ class Socks5Protocol(BaseSocksProtocol): try: host_bytes = socket.inet_pton(family, host) req = [family_to_byte[family], host_bytes, port_bytes] - self.write_request(req) - return host, port + return req, (host, port) except socket.error: pass @@ -375,8 +374,7 @@ class Socks5Protocol(BaseSocksProtocol): req = [family_to_byte[family], host_bytes, port_bytes] host = socket.inet_ntop(family, host_bytes) - self.write_request(req) - return host, port + return req, (host, port) @asyncio.coroutine def read_address(self): diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 4b2d8ab..c6b0226 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -53,6 +53,7 @@ def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop, app_protocol_factory=ap_factory, waiter=whiter) proto._stream_writer = mock.Mock() + proto._stream_writer.drain = fake_coroutine(True) if not isinstance(r, (list, tuple)): proto.read_response = mock.Mock( @@ -526,37 +527,40 @@ class TestSocks5Protocol(unittest.TestCase): req = proto.authenticate() self.loop.run_until_complete(req) - def test_wr_addr_ipv4(self): + def test_build_dst_addr_ipv4(self): proto = make_socks5(self.loop) - req = proto.write_address('127.0.0.1', 80) - self.loop.run_until_complete(req) + c = proto.build_dst_address('127.0.0.1', 80) + dst_req, resolved = self.loop.run_until_complete(c) - proto._stream_writer.write.assert_called_with( - b'\x01\x7f\x00\x00\x01\x00P') + self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P']) + self.assertEqual(resolved, ('127.0.0.1', 80)) - def test_wr_addr_ipv6(self): + def test_build_dst_addr_ipv6(self): proto = make_socks5(self.loop) - req = proto.write_address( + c = proto.build_dst_address( '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80) - self.loop.run_until_complete(req) + dst_req, resolved = self.loop.run_until_complete(c) - proto._stream_writer.write.assert_called_with( - b'\x04 \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]\x00P') + self.assertEqual(dst_req, [ + 0x04, b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', b'\x00P']) + self.assertEqual(resolved, + ('2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)) - def test_wr_addr_domain_with_remote_resolve(self): + def test_build_dst_addr_domain_with_remote_resolve(self): proto = make_socks5(self.loop) - req = proto.write_address('python.org', 80) - self.loop.run_until_complete(req) + c = proto.build_dst_address('python.org', 80) + dst_req, resolved = self.loop.run_until_complete(c) - proto._stream_writer.write.assert_called_with(b'\x03\npython.org\x00P') + self.assertEqual(dst_req, [0x03, b'\n', b'python.org', b'\x00P']) + self.assertEqual(resolved, ('python.org', 80)) - def test_wr_addr_domain_with_locale_resolve(self): + def test_build_dst_addr_domain_with_locale_resolve(self): proto = make_socks5(self.loop, rr=False) - req = proto.write_address('python.org', 80) - self.loop.run_until_complete(req) + c = proto.build_dst_address('python.org', 80) + dst_req, resolved = self.loop.run_until_complete(c) - proto._stream_writer.write.assert_called_with( - b'\x01\x7f\x00\x00\x01\x00P') + self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P']) + self.assertEqual(resolved, ('127.0.0.1', 80)) def test_rd_addr_ipv4(self): proto = make_socks5( @@ -624,6 +628,5 @@ class TestSocks5Protocol(unittest.TestCase): self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80))) proto._stream_writer.write.assert_has_calls([ mock.call(b'\x05\x02\x00\x02'), - mock.call(b'\x05\x01\x00'), - mock.call(b'\x03\npython.org\x00P') + mock.call(b'\x05\x01\x00\x03\npython.org\x00P') ])