Browse Source

switch to CachedNetworkStream based api, fix AEADEncrypt buffer management

pull/2865/head
Student Main 5 years ago
parent
commit
924cde77e2
7 changed files with 86 additions and 18 deletions
  1. +12
    -3
      shadowsocks-csharp/Controller/LoggerExtension.cs
  2. +4
    -0
      shadowsocks-csharp/Controller/Service/Listener.cs
  3. +1
    -0
      shadowsocks-csharp/Controller/Service/PACServer.cs
  4. +8
    -1
      shadowsocks-csharp/Controller/Service/PortForwarder.cs
  5. +42
    -1
      shadowsocks-csharp/Controller/Service/TCPRelay.cs
  6. +1
    -0
      shadowsocks-csharp/Controller/Service/UDPRelay.cs
  7. +18
    -13
      shadowsocks-csharp/Encryption/AEAD/AEADEncryptor.cs

+ 12
- 3
shadowsocks-csharp/Controller/LoggerExtension.cs View File

@@ -18,10 +18,15 @@ namespace NLog
} }
public static void Dump(this Logger logger, string tag, byte[] arr, int length = -1) public static void Dump(this Logger logger, string tag, byte[] arr, int length = -1)
{ {
if (arr == null) logger.Trace($@"
{tag}:
(null)
");
if (length == -1) length = arr.Length; if (length == -1) length = arr.Length;
if (!logger.IsTraceEnabled) return; if (!logger.IsTraceEnabled) return;
string hex = BitConverter.ToString(arr.AsSpan(0, length).ToArray()).Replace("-", "");
string hex = BitConverter.ToString(arr.AsSpan(0, Math.Min(arr.Length, length)).ToArray()).Replace("-", "");
string content = $@" string content = $@"
{tag}: {tag}:
{hex} {hex}
@@ -36,10 +41,15 @@ namespace NLog
} }
public static void DumpBase64(this Logger logger, string tag, byte[] arr, int length = -1) public static void DumpBase64(this Logger logger, string tag, byte[] arr, int length = -1)
{ {
if (arr == null) logger.Trace($@"
{tag}:
(null)
");
if (length == -1) length = arr.Length; if (length == -1) length = arr.Length;
if (!logger.IsTraceEnabled) return; if (!logger.IsTraceEnabled) return;
string hex =Convert.ToBase64String(arr.AsSpan(0, length).ToArray());
string hex = Convert.ToBase64String(arr.AsSpan(0, Math.Min(arr.Length, length)).ToArray());
string content = $@" string content = $@"
{tag}: {tag}:
{hex} {hex}
@@ -48,7 +58,6 @@ namespace NLog
logger.Trace(content); logger.Trace(content);
} }
public static void Debug(this Logger logger, EndPoint local, EndPoint remote, int len, string header = null, string tailer = null) public static void Debug(this Logger logger, EndPoint local, EndPoint remote, int len, string header = null, string tailer = null)
{ {
if (logger.IsDebugEnabled) if (logger.IsDebugEnabled)


+ 4
- 0
shadowsocks-csharp/Controller/Service/Listener.cs View File

@@ -15,13 +15,17 @@ namespace Shadowsocks.Controller
public interface IService public interface IService
{ {
[Obsolete]
bool Handle(byte[] firstPacket, int length, Socket socket, object state); bool Handle(byte[] firstPacket, int length, Socket socket, object state);
public abstract bool Handle(CachedNetworkStream stream, object state);
void Stop(); void Stop();
} }
public abstract class Service : IService public abstract class Service : IService
{ {
[Obsolete]
public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state); public abstract bool Handle(byte[] firstPacket, int length, Socket socket, object state);
public abstract bool Handle(CachedNetworkStream stream, object state); public abstract bool Handle(CachedNetworkStream stream, object state);


+ 1
- 0
shadowsocks-csharp/Controller/Service/PACServer.cs View File

@@ -60,6 +60,7 @@ namespace Shadowsocks.Controller
return Handle(fp, len, stream.Socket, state); return Handle(fp, len, stream.Socket, state);
} }
[Obsolete]
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{ {
if (socket.ProtocolType != ProtocolType.Tcp) if (socket.ProtocolType != ProtocolType.Tcp)


+ 8
- 1
shadowsocks-csharp/Controller/Service/PortForwarder.cs View File

@@ -19,9 +19,16 @@ namespace Shadowsocks.Controller
{ {
byte[] fp = new byte[256]; byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp); int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
if (stream.Socket.ProtocolType != ProtocolType.Tcp)
{
return false;
}
new Handler().Start(fp, len, stream.Socket, _targetPort);
return true;
} }
[Obsolete]
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{ {
if (socket.ProtocolType != ProtocolType.Tcp) if (socket.ProtocolType != ProtocolType.Tcp)


+ 42
- 1
shadowsocks-csharp/Controller/Service/TCPRelay.cs View File

@@ -35,11 +35,52 @@ namespace Shadowsocks.Controller
public override bool Handle(CachedNetworkStream stream, object state) public override bool Handle(CachedNetworkStream stream, object state)
{ {
byte[] fp = new byte[256]; byte[] fp = new byte[256];
int len = stream.ReadFirstBlock(fp); int len = stream.ReadFirstBlock(fp);
return Handle(fp, len, stream.Socket, state);
var socket = stream.Socket;
if (socket.ProtocolType != ProtocolType.Tcp
|| (len < 2 || fp[0] != 5))
return false;
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
TCPHandler handler = new TCPHandler(_controller, _config, this, socket);
IList<TCPHandler> handlersToClose = new List<TCPHandler>();
lock (Handlers)
{
Handlers.Add(handler);
DateTime now = DateTime.Now;
if (now - _lastSweepTime > TimeSpan.FromSeconds(1))
{
_lastSweepTime = now;
foreach (TCPHandler handler1 in Handlers)
if (now - handler1.lastActivity > TimeSpan.FromSeconds(900))
handlersToClose.Add(handler1);
}
}
foreach (TCPHandler handler1 in handlersToClose)
{
logger.Debug("Closing timed out TCP connection.");
handler1.Close();
}
/*
* Start after we put it into Handlers set. Otherwise if it failed in handler.Start()
* then it will call handler.Close() before we add it into the set.
* Then the handler will never release until the next Handle call. Sometimes it will
* cause odd problems (especially during memory profiling).
*/
handler.Start(fp, len);
return true;
// return Handle(fp, len, stream.Socket, state);
} }
[Obsolete]
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{ {
if (socket.ProtocolType != ProtocolType.Tcp if (socket.ProtocolType != ProtocolType.Tcp


+ 1
- 0
shadowsocks-csharp/Controller/Service/UDPRelay.cs View File

@@ -32,6 +32,7 @@ namespace Shadowsocks.Controller
return Handle(fp, len, stream.Socket, state); return Handle(fp, len, stream.Socket, state);
} }
[Obsolete]
public override bool Handle(byte[] firstPacket, int length, Socket socket, object state) public override bool Handle(byte[] firstPacket, int length, Socket socket, object state)
{ {
if (socket.ProtocolType != ProtocolType.Udp) if (socket.ProtocolType != ProtocolType.Udp)


+ 18
- 13
shadowsocks-csharp/Encryption/AEAD/AEADEncryptor.cs View File

@@ -174,7 +174,7 @@ namespace Shadowsocks.Encryption.AEAD
cipher.CopyTo(tmp.Slice(bufPtr)); cipher.CopyTo(tmp.Slice(bufPtr));
int bufSize = tmp.Length; int bufSize = tmp.Length;
logger.Debug("---Start Decryption");
logger.Debug($"{instanceId} decrypt tcp, read salt: {!saltReady}");
if (!saltReady) if (!saltReady)
{ {
// check if we get the leading salt // check if we get the leading salt
@@ -187,12 +187,10 @@ namespace Shadowsocks.Encryption.AEAD
} }
saltReady = true; saltReady = true;
// buffer.Get(saltLen);
byte[] salt = tmp.Slice(0, saltLen).ToArray(); byte[] salt = tmp.Slice(0, saltLen).ToArray();
tmp = tmp.Slice(saltLen); tmp = tmp.Slice(saltLen);
InitCipher(salt, false); InitCipher(salt, false);
logger.Debug("get salt len " + saltLen);
} }
// handle chunks // handle chunks
@@ -202,7 +200,7 @@ namespace Shadowsocks.Encryption.AEAD
// check if we have any data // check if we have any data
if (bufSize <= 0) if (bufSize <= 0)
{ {
logger.Debug("No data in buffer");
logger.Trace("No data in buffer");
return outlength; return outlength;
} }
@@ -211,27 +209,32 @@ namespace Shadowsocks.Encryption.AEAD
{ {
// so we only have chunk length and its tag? // so we only have chunk length and its tag?
// wait more // wait more
logger.Trace($"{instanceId} not enough data to decrypt chunk. write {tmp.Length} byte back to buffer.");
tmp.CopyTo(buffer); tmp.CopyTo(buffer);
bufPtr = tmp.Length; bufPtr = tmp.Length;
return outlength; return outlength;
} }
logger.Trace($"{instanceId} try decrypt to offset {outlength}");
int len = ChunkDecrypt(plain.Slice(outlength), tmp); int len = ChunkDecrypt(plain.Slice(outlength), tmp);
if (len <= 0) if (len <= 0)
{ {
logger.Trace($"{instanceId} no chunk decrypted, write {tmp.Length} byte back to buffer.");
// no chunk decrypted // no chunk decrypted
tmp.CopyTo(buffer); tmp.CopyTo(buffer);
bufPtr = tmp.Length; bufPtr = tmp.Length;
return outlength; return outlength;
} }
logger.Trace($"{instanceId} decrypted {len} to offset {outlength}");
// drop decrypted data // drop decrypted data
tmp = tmp.Slice(ChunkLengthBytes + tagLen + len + tagLen); tmp = tmp.Slice(ChunkLengthBytes + tagLen + len + tagLen);
outlength += len; outlength += len;
logger.Debug("aead dec outlength " + outlength);
// logger.Debug("aead dec outlength " + outlength);
if (outlength + 100 > TCPHandler.BufferSize) if (outlength + 100 > TCPHandler.BufferSize)
{ {
logger.Debug("dec outbuf almost full, giving up");
logger.Trace($"{instanceId} output almost full, write {tmp.Length} byte back to buffer.");
tmp.CopyTo(buffer); tmp.CopyTo(buffer);
bufPtr = tmp.Length; bufPtr = tmp.Length;
return outlength; return outlength;
@@ -240,7 +243,8 @@ namespace Shadowsocks.Encryption.AEAD
// check if we already done all of them // check if we already done all of them
if (bufSize <= 0) if (bufSize <= 0)
{ {
logger.Debug("No data in _decCircularBuffer, already all done");
bufPtr = 0;
logger.Debug($"{instanceId} no data in buffer, already all done");
return outlength; return outlength;
} }
} }
@@ -266,7 +270,7 @@ namespace Shadowsocks.Encryption.AEAD
#endregion #endregion
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
[MethodImpl(MethodImplOptions.Synchronized | MethodImplOptions.AggressiveOptimization)]
private int ChunkEncrypt(ReadOnlySpan<byte> plain, Span<byte> cipher) private int ChunkEncrypt(ReadOnlySpan<byte> plain, Span<byte> cipher)
{ {
if (plain.Length > ChunkLengthMask) if (plain.Length > ChunkLengthMask)
@@ -284,7 +288,7 @@ namespace Shadowsocks.Encryption.AEAD
return cipherLenSize + cipherDataSize; return cipherLenSize + cipherDataSize;
} }
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
[MethodImpl(MethodImplOptions.Synchronized | MethodImplOptions.AggressiveOptimization)]
private int ChunkDecrypt(Span<byte> plain, ReadOnlySpan<byte> cipher) private int ChunkDecrypt(Span<byte> plain, ReadOnlySpan<byte> cipher)
{ {
// try to dec chunk len // try to dec chunk len
@@ -294,14 +298,14 @@ namespace Shadowsocks.Encryption.AEAD
if (chunkLength > ChunkLengthMask) if (chunkLength > ChunkLengthMask)
{ {
// we get invalid chunk // we get invalid chunk
logger.Error($"Invalid chunk length: {chunkLength}");
logger.Error($"{instanceId} Invalid chunk length: {chunkLength}");
throw new CryptoErrorException(); throw new CryptoErrorException();
} }
logger.Debug("Get the real chunk len:" + chunkLength);
// logger.Debug("Get the real chunk len:" + chunkLength);
int bufSize = cipher.Length; int bufSize = cipher.Length;
if (bufSize < ChunkLengthBytes + tagLen /* we haven't remove them */+ chunkLength + tagLen) if (bufSize < ChunkLengthBytes + tagLen /* we haven't remove them */+ chunkLength + tagLen)
{ {
logger.Debug("No data to decrypt one chunk");
logger.Debug($"{instanceId} need {ChunkLengthBytes + tagLen + chunkLength + tagLen}, but have {cipher.Length}");
return 0; return 0;
} }
CryptoUtils.SodiumIncrement(nonce); CryptoUtils.SodiumIncrement(nonce);
@@ -309,6 +313,7 @@ namespace Shadowsocks.Encryption.AEAD
// drop chunk len and its tag from buffer // drop chunk len and its tag from buffer
int len = CipherDecrypt(plain, cipher.Slice(ChunkLengthBytes + tagLen, chunkLength + tagLen)); int len = CipherDecrypt(plain, cipher.Slice(ChunkLengthBytes + tagLen, chunkLength + tagLen));
CryptoUtils.SodiumIncrement(nonce); CryptoUtils.SodiumIncrement(nonce);
logger.Trace($"{instanceId} decrypted {len} byte chunk used {ChunkLengthBytes + tagLen + chunkLength + tagLen} from {cipher.Length}");
return len; return len;
} }
} }

Loading…
Cancel
Save