@@ -0,0 +1,222 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Net.Sockets; | |||||
using System.Runtime.CompilerServices; | |||||
using System.Text; | |||||
using System.Threading; | |||||
using System.Threading.Tasks; | |||||
namespace Shadowsocks.Controller | |||||
{ | |||||
// cache first packet for duty-chain pattern listener | |||||
public class CachedNetworkStream : Stream | |||||
{ | |||||
// 256 byte first packet buffer should enough for 99.999...% situation | |||||
// socks5: 0x05 0x.... | |||||
// http-pac: GET /pac HTTP/1.1 | |||||
// http-proxy: /[a-z]+ .+ HTTP\/1\.[01]/i | |||||
public const int MaxCache = 256; | |||||
public Socket Socket { get; private set; } | |||||
private readonly Stream s; | |||||
private byte[] cache = new byte[MaxCache]; | |||||
private long cachePtr = 0; | |||||
private long readPtr = 0; | |||||
public CachedNetworkStream(Socket socket) | |||||
{ | |||||
s = new NetworkStream(socket); | |||||
Socket = socket; | |||||
} | |||||
/// <summary> | |||||
/// Only for test purpose | |||||
/// </summary> | |||||
/// <param name="stream"></param> | |||||
public CachedNetworkStream(Stream stream) | |||||
{ | |||||
s = stream; | |||||
} | |||||
public override bool CanRead => s.CanRead; | |||||
// we haven't run out of cache | |||||
public override bool CanSeek => cachePtr == readPtr; | |||||
public override bool CanWrite => s.CanWrite; | |||||
public override long Length => s.Length; | |||||
public override long Position { get => readPtr; set => Seek(value, SeekOrigin.Begin); } | |||||
public override void Flush() | |||||
{ | |||||
s.Flush(); | |||||
} | |||||
//public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) | |||||
//{ | |||||
// var endPtr = buffer.Length + readPtr; // expected ptr after operation | |||||
// var uncachedLen = Math.Max(endPtr - cachePtr, 0); // how many data from socket | |||||
// var cachedLen = buffer.Length - uncachedLen; // how many data from cache | |||||
// var emptyCacheLen = MaxCache - cachePtr; // how many cache remain | |||||
// int readLen = 0; | |||||
// var cachedMem = buffer[..(int)cachedLen]; | |||||
// var uncachedMem = buffer[(int)cachedLen..]; | |||||
// if (cachedLen > 0) | |||||
// { | |||||
// cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedMem); | |||||
// readPtr += cachedLen; | |||||
// readLen += (int)cachedLen; | |||||
// } | |||||
// if (uncachedLen > 0) | |||||
// { | |||||
// int readStreamLen = await s.ReadAsync(cachedMem, cancellationToken); | |||||
// int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); // how many data need to cache | |||||
// if (lengthToCache > 0) | |||||
// { | |||||
// uncachedMem[0..lengthToCache].CopyTo(cache[(int)cachePtr..]); | |||||
// cachePtr += lengthToCache; | |||||
// } | |||||
// readPtr += readStreamLen; | |||||
// readLen += readStreamLen; | |||||
// } | |||||
// return readLen; | |||||
//} | |||||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
public override int Read(byte[] buffer, int offset, int count) | |||||
{ | |||||
Span<byte> span = buffer.AsSpan(offset, count); | |||||
return Read(span); | |||||
} | |||||
[MethodImpl(MethodImplOptions.AggressiveOptimization)] | |||||
public override int Read(Span<byte> buffer) | |||||
{ | |||||
// how many data from socket | |||||
// r: readPtr, c: cachePtr, e: endPtr | |||||
// ptr 0 r c e | |||||
// cached ####+++++ | |||||
// read ++++ | |||||
// ptr 0 c r e | |||||
// cached ##### | |||||
// read +++++ | |||||
var endPtr = buffer.Length + readPtr; // expected ptr after operation | |||||
var uncachedLen = Math.Max(endPtr - Math.Max(cachePtr, readPtr), 0); | |||||
var cachedLen = buffer.Length - uncachedLen; // how many data from cache | |||||
var emptyCacheLen = MaxCache - cachePtr; // how many cache remain | |||||
int readLen = 0; | |||||
Span<byte> cachedSpan = buffer[..(int)cachedLen]; | |||||
Span<byte> uncachedSpan = buffer[(int)cachedLen..]; | |||||
if (cachedLen > 0) | |||||
{ | |||||
cache[(int)readPtr..(int)(readPtr + cachedLen)].CopyTo(cachedSpan); | |||||
readPtr += cachedLen; | |||||
readLen += (int)cachedLen; | |||||
} | |||||
if (uncachedLen > 0) | |||||
{ | |||||
int readStreamLen = s.Read(uncachedSpan); | |||||
// how many data need to cache | |||||
int lengthToCache = (int)Math.Min(emptyCacheLen, readStreamLen); | |||||
if (lengthToCache > 0) | |||||
{ | |||||
uncachedSpan[0..lengthToCache].ToArray().CopyTo(cache, cachePtr); | |||||
cachePtr += lengthToCache; | |||||
} | |||||
readPtr += readStreamLen; | |||||
readLen += readStreamLen; | |||||
} | |||||
return readLen; | |||||
} | |||||
/// <summary> | |||||
/// Read first block, will never read into non-cache range | |||||
/// </summary> | |||||
/// <param name="buffer"></param> | |||||
/// <returns></returns> | |||||
public int ReadFirstBlock(Span<byte> buffer) | |||||
{ | |||||
Seek(0, SeekOrigin.Begin); | |||||
int len = Math.Min(MaxCache, buffer.Length); | |||||
return Read(buffer[0..len]); | |||||
} | |||||
/// <summary> | |||||
/// Seek position, only support seek to cached range when we haven't read into non-cache range | |||||
/// </summary> | |||||
/// <param name="offset"></param> | |||||
/// <param name="origin">Set it to System.IO.SeekOrigin.Begin, otherwise it will throw System.NotSupportedException</param> | |||||
/// <exception cref="IOException"></exception> | |||||
/// <exception cref="NotSupportedException"></exception> | |||||
/// <exception cref="ObjectDisposedException"></exception> | |||||
/// <returns></returns> | |||||
public override long Seek(long offset, SeekOrigin origin) | |||||
{ | |||||
if (!CanSeek) throw new NotSupportedException("Non cache data has been read"); | |||||
if (origin != SeekOrigin.Begin) throw new NotSupportedException("We don't know network stream's length"); | |||||
if (offset < 0 || offset > cachePtr) throw new NotSupportedException("Can't seek to uncached position"); | |||||
readPtr = offset; | |||||
return Position; | |||||
} | |||||
/// <summary> | |||||
/// Useless | |||||
/// </summary> | |||||
/// <param name="value"></param> | |||||
/// <exception cref="NotSupportedException"></exception> | |||||
public override void SetLength(long value) | |||||
{ | |||||
s.SetLength(value); | |||||
} | |||||
/// <summary> | |||||
/// Write to underly stream | |||||
/// </summary> | |||||
/// <param name="buffer"></param> | |||||
/// <param name="offset"></param> | |||||
/// <param name="count"></param> | |||||
/// <param name="cancellationToken"></param> | |||||
/// <returns></returns> | |||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) | |||||
{ | |||||
return s.WriteAsync(buffer, offset, count, cancellationToken); | |||||
} | |||||
/// <summary> | |||||
/// Write to underly stream | |||||
/// </summary> | |||||
/// <param name="buffer"></param> | |||||
/// <param name="offset"></param> | |||||
/// <param name="count"></param> | |||||
public override void Write(byte[] buffer, int offset, int count) | |||||
{ | |||||
s.Write(buffer, offset, count); | |||||
} | |||||
protected override void Dispose(bool disposing) | |||||
{ | |||||
s.Dispose(); | |||||
base.Dispose(disposing); | |||||
} | |||||
} | |||||
} |
@@ -24,6 +24,8 @@ namespace Shadowsocks.Controller | |||||
{ | { | ||||
public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state); | public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state); | ||||
public abstract bool Handle(CachedNetworkStream stream, object state); | |||||
public virtual void Stop() { } | public virtual void Stop() { } | ||||
} | } | ||||
@@ -53,6 +53,13 @@ namespace Shadowsocks.Controller | |||||
return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content))); | return HttpServerUtilityUrlToken.Encode(CryptoUtils.MD5(Encoding.ASCII.GetBytes(content))); | ||||
} | } | ||||
public override bool Handle(CachedNetworkStream stream, object state) | |||||
{ | |||||
byte[] fp = new byte[256]; | |||||
int len = stream.ReadFirstBlock(fp); | |||||
return Handle(fp, len, stream.Socket, state); | |||||
} | |||||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | ||||
{ | { | ||||
if (socket.ProtocolType != ProtocolType.Tcp) | if (socket.ProtocolType != ProtocolType.Tcp) | ||||
@@ -154,8 +161,6 @@ namespace Shadowsocks.Controller | |||||
} | } | ||||
} | } | ||||
public void SendResponse(Socket socket, bool useSocks) | public void SendResponse(Socket socket, bool useSocks) | ||||
{ | { | ||||
try | try | ||||
@@ -195,7 +200,6 @@ Connection: Close | |||||
{ } | { } | ||||
} | } | ||||
private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks) | private string GetPACAddress(IPEndPoint localEndPoint, bool useSocks) | ||||
{ | { | ||||
return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6 | return localEndPoint.AddressFamily == AddressFamily.InterNetworkV6 | ||||
@@ -15,6 +15,13 @@ namespace Shadowsocks.Controller | |||||
_targetPort = targetPort; | _targetPort = targetPort; | ||||
} | } | ||||
public override bool Handle(CachedNetworkStream stream, object state) | |||||
{ | |||||
byte[] fp = new byte[256]; | |||||
int len = stream.ReadFirstBlock(fp); | |||||
return Handle(fp, len, stream.Socket, state); | |||||
} | |||||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | ||||
{ | { | ||||
if (socket.ProtocolType != ProtocolType.Tcp) | if (socket.ProtocolType != ProtocolType.Tcp) | ||||
@@ -33,6 +33,13 @@ namespace Shadowsocks.Controller | |||||
_lastSweepTime = DateTime.Now; | _lastSweepTime = DateTime.Now; | ||||
} | } | ||||
public override bool Handle(CachedNetworkStream stream, object state) | |||||
{ | |||||
byte[] fp = new byte[256]; | |||||
int len = stream.ReadFirstBlock(fp); | |||||
return Handle(fp, len, stream.Socket, state); | |||||
} | |||||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | ||||
{ | { | ||||
if (socket.ProtocolType != ProtocolType.Tcp | if (socket.ProtocolType != ProtocolType.Tcp | ||||
@@ -25,6 +25,13 @@ namespace Shadowsocks.Controller | |||||
this._controller = controller; | this._controller = controller; | ||||
} | } | ||||
public override bool Handle(CachedNetworkStream stream, object state) | |||||
{ | |||||
byte[] fp = new byte[256]; | |||||
int len = stream.ReadFirstBlock(fp); | |||||
return Handle(fp, len, stream.Socket, state); | |||||
} | |||||
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) | ||||
{ | { | ||||
if (socket.ProtocolType != ProtocolType.Udp) | if (socket.ProtocolType != ProtocolType.Udp) | ||||
@@ -0,0 +1,84 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
using Shadowsocks.Controller; | |||||
namespace Shadowsocks.Test | |||||
{ | |||||
[TestClass] | |||||
public class CachedNetworkStreamTest | |||||
{ | |||||
byte[] b0 = new byte[256]; | |||||
byte[] b1 = new byte[256]; | |||||
byte[] b2 = new byte[1024]; | |||||
// [TestInitialize] | |||||
[TestInitialize] | |||||
public void init() | |||||
{ | |||||
for (int i = 0; i < 256; i++) | |||||
{ | |||||
b0[i] = (byte)i; | |||||
b1[i] = (byte)(255 - i); | |||||
} | |||||
b0.CopyTo(b2, 0); | |||||
b1.CopyTo(b2, 256); | |||||
b0.CopyTo(b2, 512); | |||||
} | |||||
[TestMethod] | |||||
public void StreamTest() | |||||
{ | |||||
using MemoryStream ms = new MemoryStream(b2); | |||||
using CachedNetworkStream s = new CachedNetworkStream(ms); | |||||
byte[] o = new byte[128]; | |||||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||||
TestUtils.ArrayEqual(b0[0..128], o); | |||||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||||
TestUtils.ArrayEqual(b0[128..192], o[0..64]); | |||||
s.Seek(0, SeekOrigin.Begin); | |||||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||||
TestUtils.ArrayEqual(b0[0..64], o[0..64]); | |||||
// refuse to go out of cached range | |||||
Assert.ThrowsException<NotSupportedException>(() => | |||||
{ | |||||
s.Seek(193, SeekOrigin.Begin); | |||||
}); | |||||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||||
TestUtils.ArrayEqual(b0[64..192], o); | |||||
Assert.IsTrue(s.CanSeek); | |||||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||||
TestUtils.ArrayEqual(b0[192..256], o[0..64]); | |||||
TestUtils.ArrayEqual(b1[0..64], o[64..128]); | |||||
Assert.IsFalse(s.CanSeek); | |||||
// refuse to go back when non-cache data has been read | |||||
Assert.ThrowsException<NotSupportedException>(() => | |||||
{ | |||||
s.Seek(0, SeekOrigin.Begin); | |||||
}); | |||||
// read in non-cache range | |||||
Assert.AreEqual(64, s.Read(o, 0, 64)); | |||||
s.Read(o, 0, 128); | |||||
Assert.AreEqual(512, s.Position); | |||||
Assert.AreEqual(128, s.Read(o, 0, 128)); | |||||
TestUtils.ArrayEqual(b0[0..128], o); | |||||
s.Read(o, 0, 128); | |||||
s.Read(o, 0, 128); | |||||
s.Read(o, 0, 128); | |||||
// read at eos | |||||
Assert.AreEqual(0, s.Read(o, 0, 128)); | |||||
} | |||||
} | |||||
} |
@@ -42,7 +42,7 @@ namespace Shadowsocks.Test | |||||
//encryptor.Encrypt(plain, length, cipher, out int outLen); | //encryptor.Encrypt(plain, length, cipher, out int outLen); | ||||
//decryptor.Decrypt(cipher, outLen, plain2, out int outLen2); | //decryptor.Decrypt(cipher, outLen, plain2, out int outLen2); | ||||
Assert.AreEqual(length, outLen2); | Assert.AreEqual(length, outLen2); | ||||
ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); | |||||
TestUtils.ArrayEqual<byte>(plain.AsSpan(0, length).ToArray(), plain2.AsSpan(0, length).ToArray()); | |||||
} | } | ||||
const string password = "barfoo!"; | const string password = "barfoo!"; | ||||
@@ -70,36 +70,6 @@ namespace Shadowsocks.Test | |||||
throw; | throw; | ||||
} | } | ||||
} | } | ||||
private void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "") | |||||
{ | |||||
var e1 = expected.GetEnumerator(); | |||||
var e2 = actual.GetEnumerator(); | |||||
int ctr = 0; | |||||
while (true) | |||||
{ | |||||
var e1next = e1.MoveNext(); | |||||
var e2next = e2.MoveNext(); | |||||
if (e1next && e2next) | |||||
{ | |||||
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); | |||||
} | |||||
else if (!e1next && !e2next) | |||||
{ | |||||
return; | |||||
} | |||||
else if (!e1next) | |||||
{ | |||||
Assert.Fail($"actual longer than expected ({ctr}) {msg}"); | |||||
} | |||||
else | |||||
{ | |||||
Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); | |||||
} | |||||
} | |||||
} | |||||
private static bool encryptionFailed = false; | private static bool encryptionFailed = false; | ||||
private void TestEncryptionMethod(Type enc, string method) | private void TestEncryptionMethod(Type enc, string method) | ||||
@@ -4,6 +4,8 @@ | |||||
<TargetFramework>netcoreapp3.1</TargetFramework> | <TargetFramework>netcoreapp3.1</TargetFramework> | ||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
<RootNamespace>Shadowsocks.Test</RootNamespace> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
@@ -0,0 +1,41 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Shadowsocks.Test | |||||
{ | |||||
class TestUtils | |||||
{ | |||||
public static void ArrayEqual<T>(IEnumerable<T> expected, IEnumerable<T> actual, string msg = "") | |||||
{ | |||||
var e1 = expected.GetEnumerator(); | |||||
var e2 = actual.GetEnumerator(); | |||||
int ctr = 0; | |||||
while (true) | |||||
{ | |||||
var e1next = e1.MoveNext(); | |||||
var e2next = e2.MoveNext(); | |||||
if (e1next && e2next) | |||||
{ | |||||
Assert.AreEqual(e1.Current, e2.Current, "at " + ctr); | |||||
} | |||||
else if (!e1next && !e2next) | |||||
{ | |||||
return; | |||||
} | |||||
else if (!e1next) | |||||
{ | |||||
Assert.Fail($"actual longer than expected ({ctr}) {msg}"); | |||||
} | |||||
else | |||||
{ | |||||
Assert.Fail($"actual shorter than expected ({ctr}) {msg}"); | |||||
} | |||||
ctr++; | |||||
} | |||||
} | |||||
} | |||||
} |