diff --git a/shadowsocks-csharp/Controller/Service/PortForwarder.cs b/shadowsocks-csharp/Controller/Service/PortForwarder.cs index f76a1284..dcac75bf 100644 --- a/shadowsocks-csharp/Controller/Service/PortForwarder.cs +++ b/shadowsocks-csharp/Controller/Service/PortForwarder.cs @@ -48,12 +48,8 @@ namespace Shadowsocks.Controller { EndPoint remoteEP = SocketUtil.GetEndPoint("localhost", targetPort); - _remote = SocketUtil.CreateSocket(remoteEP); - _remote.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); - // Connect to the remote endpoint. - _remote.BeginConnect(remoteEP, - new AsyncCallback(ConnectCallback), null); + SocketUtil.BeginConnectTcp(remoteEP, ConnectCallback, null); } catch (Exception e) { @@ -70,7 +66,7 @@ namespace Shadowsocks.Controller } try { - _remote.EndConnect(ar); + _remote = SocketUtil.EndConnectTcp(ar); HandshakeReceive(); } catch (Exception e) diff --git a/shadowsocks-csharp/Proxy/DirectConnect.cs b/shadowsocks-csharp/Proxy/DirectConnect.cs index 487adc41..76bdbc77 100644 --- a/shadowsocks-csharp/Proxy/DirectConnect.cs +++ b/shadowsocks-csharp/Proxy/DirectConnect.cs @@ -55,17 +55,12 @@ namespace Shadowsocks.Proxy { DestEndPoint = destEndPoint; - if (_remote == null) - { - _remote = SocketUtil.CreateSocket(destEndPoint); - _remote.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); - } - _remote.BeginConnect(destEndPoint, callback, state); + SocketUtil.BeginConnectTcp(destEndPoint, callback, state); } public void EndConnectDest(IAsyncResult asyncResult) { - _remote?.EndConnect(asyncResult); + _remote = SocketUtil.EndConnectTcp(asyncResult); } public void BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback callback, diff --git a/shadowsocks-csharp/Proxy/Socks5Proxy.cs b/shadowsocks-csharp/Proxy/Socks5Proxy.cs index adb94402..4ee6917c 100644 --- a/shadowsocks-csharp/Proxy/Socks5Proxy.cs +++ b/shadowsocks-csharp/Proxy/Socks5Proxy.cs @@ -52,16 +52,13 @@ namespace Shadowsocks.Proxy public void BeginConnectProxy(EndPoint remoteEP, AsyncCallback callback, object state) { - _remote = SocketUtil.CreateSocket(remoteEP); - _remote.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); - var st = new Socks5State(); st.Callback = callback; st.AsyncState = state; ProxyEndPoint = remoteEP; - _remote.BeginConnect(remoteEP, ConnectCallback, st); + SocketUtil.BeginConnectTcp(remoteEP, ConnectCallback, st); } public void EndConnectProxy(IAsyncResult asyncResult) @@ -180,7 +177,7 @@ namespace Shadowsocks.Proxy var state = (Socks5State) ar.AsyncState; try { - _remote.EndConnect(ar); + _remote = SocketUtil.EndConnectTcp(ar); byte[] handshake = {5, 1, 0}; _remote.BeginSend(handshake, 0, handshake.Length, 0, Socks5HandshakeSendCallback, state); diff --git a/shadowsocks-csharp/Util/SocketUtil.cs b/shadowsocks-csharp/Util/SocketUtil.cs index 46e18aa1..d7543b5b 100644 --- a/shadowsocks-csharp/Util/SocketUtil.cs +++ b/shadowsocks-csharp/Util/SocketUtil.cs @@ -1,6 +1,7 @@ using System; using System.Net; using System.Net.Sockets; +using System.Threading; namespace Shadowsocks.Util { @@ -35,33 +36,68 @@ namespace Shadowsocks.Util return new DnsEndPoint2(host, port); } - public static Socket CreateSocket(EndPoint endPoint, ProtocolType protocolType = ProtocolType.Tcp) + private class TcpUserToken : IAsyncResult { - SocketType socketType; - switch (protocolType) + public AsyncCallback Callback { get; } + public SocketAsyncEventArgs Args { get; } + + public TcpUserToken(AsyncCallback callback, object state, SocketAsyncEventArgs args) { - case ProtocolType.Tcp: - socketType = SocketType.Stream; - break; - case ProtocolType.Udp: - socketType = SocketType.Dgram; - break; - default: - throw new NotSupportedException("Protocol " + protocolType + " doesn't supported!"); + Callback = callback; + AsyncState = state; + Args = args; } - if (endPoint is DnsEndPoint) - { - // use dual-mode socket - var socket = new Socket(AddressFamily.InterNetworkV6, socketType, protocolType); - socket.SetSocketOption(SocketOptionLevel.IPv6, (SocketOptionName)27, false); + public bool IsCompleted { get; } = true; + public WaitHandle AsyncWaitHandle { get; } = null; + public object AsyncState { get; } + public bool CompletedSynchronously { get; } = true; + } + + private static void OnTcpConnectCompleted(object sender, SocketAsyncEventArgs args) + { + TcpUserToken token = (TcpUserToken) args.UserToken; + + token.Callback(token); + } + + public static void BeginConnectTcp(EndPoint endPoint, AsyncCallback callback, object state) + { + var arg = new SocketAsyncEventArgs(); + arg.RemoteEndPoint = endPoint; + arg.Completed += OnTcpConnectCompleted; + arg.UserToken = new TcpUserToken(callback, state, arg); + - return socket; + Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, arg); + } + + public static Socket EndConnectTcp(IAsyncResult asyncResult) + { + var tut = asyncResult as TcpUserToken; + if (tut == null) + { + throw new ArgumentException("Invalid asyncResult.", nameof(asyncResult)); } - else + + var arg = tut.Args; + + if (arg.SocketError != SocketError.Success) { - return new Socket(endPoint.AddressFamily, socketType, protocolType); + if (arg.ConnectByNameError != null) + { + throw arg.ConnectByNameError; + } + + var ex = new SocketException((int)arg.SocketError); + throw ex; } + + var so = tut.Args.ConnectSocket; + + so.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); + + return so; } } }