diff --git a/src/Discord.Net.Core/Audio/IAudioClient.cs b/src/Discord.Net.Core/Audio/IAudioClient.cs
index 472ad32f1..4a6ae2e27 100644
--- a/src/Discord.Net.Core/Audio/IAudioClient.cs
+++ b/src/Discord.Net.Core/Audio/IAudioClient.cs
@@ -14,7 +14,7 @@ namespace Discord.Audio
/// Gets the estimated round-trip latency, in milliseconds, to the gateway server.
int Latency { get; }
- Task DisconnectAsync();
+ Task StopAsync();
///
/// Creates a new outgoing stream accepting Opus-encoded data.
diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs
index b9a08d32d..1c5ec41c1 100644
--- a/src/Discord.Net.Core/IDiscordClient.cs
+++ b/src/Discord.Net.Core/IDiscordClient.cs
@@ -10,8 +10,8 @@ namespace Discord
ConnectionState ConnectionState { get; }
ISelfUser CurrentUser { get; }
- Task ConnectAsync();
- Task DisconnectAsync();
+ Task StartAsync();
+ Task StopAsync();
Task GetApplicationInfoAsync();
diff --git a/src/Discord.Net.Core/Logging/LogMessage.cs b/src/Discord.Net.Core/Logging/LogMessage.cs
index 9c3dfcfea..d1b3782be 100644
--- a/src/Discord.Net.Core/Logging/LogMessage.cs
+++ b/src/Discord.Net.Core/Logging/LogMessage.cs
@@ -19,7 +19,7 @@ namespace Discord
}
public override string ToString() => ToString(null);
- public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 9)
+ public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 11)
{
string sourceName = Source;
string message = Message;
@@ -87,8 +87,11 @@ namespace Discord
}
if (exMessage != null)
{
- builder.Append(':');
- builder.AppendLine();
+ if (!string.IsNullOrEmpty(Message))
+ {
+ builder.Append(':');
+ builder.AppendLine();
+ }
builder.Append(exMessage);
}
diff --git a/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj b/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj
index da0c8b0fd..829951d19 100644
--- a/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj
+++ b/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj
@@ -6,7 +6,7 @@
netstandard1.6
Discord.Net.DebugTools
RogueException
- A Discord.Net extension adding random helper classes for diagnosing issues.
+ A Discord.Net extension adding some helper classes for diagnosing issues.
discord;discordapp
https://github.com/RogueException/Discord.Net
http://opensource.org/licenses/MIT
diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs
index 8948c87dc..80c4cb598 100644
--- a/src/Discord.Net.Rest/BaseDiscordClient.cs
+++ b/src/Discord.Net.Rest/BaseDiscordClient.cs
@@ -18,10 +18,9 @@ namespace Discord.Rest
public event Func LoggedOut { add { _loggedOutEvent.Add(value); } remove { _loggedOutEvent.Remove(value); } }
private readonly AsyncEvent> _loggedOutEvent = new AsyncEvent>();
- internal readonly Logger _restLogger, _queueLogger;
- internal readonly SemaphoreSlim _connectionLock;
- private bool _isFirstLogin;
- private bool _isDisposed;
+ internal readonly Logger _restLogger;
+ private readonly SemaphoreSlim _stateLock;
+ private bool _isFirstLogin, _isDisposed;
internal API.DiscordRestApiClient ApiClient { get; }
internal LogManager LogManager { get; }
@@ -35,17 +34,16 @@ namespace Discord.Rest
LogManager = new LogManager(config.LogLevel);
LogManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);
- _connectionLock = new SemaphoreSlim(1, 1);
+ _stateLock = new SemaphoreSlim(1, 1);
_restLogger = LogManager.CreateLogger("Rest");
- _queueLogger = LogManager.CreateLogger("Queue");
_isFirstLogin = config.DisplayInitialLog;
ApiClient.RequestQueue.RateLimitTriggered += async (id, info) =>
{
if (info == null)
- await _queueLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
+ await _restLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
else
- await _queueLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
+ await _restLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false);
};
ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false);
}
@@ -53,12 +51,12 @@ namespace Discord.Rest
///
public async Task LoginAsync(TokenType tokenType, string token, bool validateToken = true)
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
+ await _stateLock.WaitAsync().ConfigureAwait(false);
try
{
await LoginInternalAsync(tokenType, token).ConfigureAwait(false);
}
- finally { _connectionLock.Release(); }
+ finally { _stateLock.Release(); }
}
private async Task LoginInternalAsync(TokenType tokenType, string token)
{
@@ -86,17 +84,17 @@ namespace Discord.Rest
await _loggedInEvent.InvokeAsync().ConfigureAwait(false);
}
- protected virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); }
+ internal virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); }
///
public async Task LogoutAsync()
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
+ await _stateLock.WaitAsync().ConfigureAwait(false);
try
{
await LogoutInternalAsync().ConfigureAwait(false);
}
- finally { _connectionLock.Release(); }
+ finally { _stateLock.Release(); }
}
private async Task LogoutInternalAsync()
{
@@ -111,7 +109,7 @@ namespace Discord.Rest
await _loggedOutEvent.InvokeAsync().ConfigureAwait(false);
}
- protected virtual Task OnLogoutAsync() { return Task.Delay(0); }
+ internal virtual Task OnLogoutAsync() { return Task.Delay(0); }
internal virtual void Dispose(bool disposing)
{
@@ -161,8 +159,9 @@ namespace Discord.Rest
Task IDiscordClient.GetVoiceRegionAsync(string id)
=> Task.FromResult(null);
- Task IDiscordClient.ConnectAsync() { throw new NotSupportedException(); }
- Task IDiscordClient.DisconnectAsync() { throw new NotSupportedException(); }
-
+ Task IDiscordClient.StartAsync()
+ => Task.Delay(0);
+ Task IDiscordClient.StopAsync()
+ => Task.Delay(0);
}
}
diff --git a/src/Discord.Net.Rest/ConnectionManager.cs b/src/Discord.Net.Rest/ConnectionManager.cs
new file mode 100644
index 000000000..ab1f4790c
--- /dev/null
+++ b/src/Discord.Net.Rest/ConnectionManager.cs
@@ -0,0 +1,199 @@
+using Discord.Logging;
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Discord
+{
+ internal class ConnectionManager
+ {
+ public event Func Connected { add { _connectedEvent.Add(value); } remove { _connectedEvent.Remove(value); } }
+ private readonly AsyncEvent> _connectedEvent = new AsyncEvent>();
+ public event Func Disconnected { add { _disconnectedEvent.Add(value); } remove { _disconnectedEvent.Remove(value); } }
+ private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>();
+
+ private readonly SemaphoreSlim _stateLock;
+ private readonly Logger _logger;
+ private readonly int _connectionTimeout;
+ private readonly Func _onConnecting;
+ private readonly Func _onDisconnecting;
+
+ private TaskCompletionSource _connectionPromise, _readyPromise;
+ private CancellationTokenSource _combinedCancelToken, _reconnectCancelToken, _connectionCancelToken;
+ private Task _task;
+
+ public ConnectionState State { get; private set; }
+ public CancellationToken CancelToken { get; private set; }
+
+ public bool IsCompleted => _readyPromise.Task.IsCompleted;
+
+ internal ConnectionManager(SemaphoreSlim stateLock, Logger logger, int connectionTimeout,
+ Func onConnecting, Func onDisconnecting, Action> clientDisconnectHandler)
+ {
+ _stateLock = stateLock;
+ _logger = logger;
+ _connectionTimeout = connectionTimeout;
+ _onConnecting = onConnecting;
+ _onDisconnecting = onDisconnecting;
+
+ clientDisconnectHandler(ex =>
+ {
+ if (ex != null)
+ Error(new Exception("WebSocket connection was closed", ex));
+ else
+ Error(new Exception("WebSocket connection was closed"));
+ return Task.Delay(0);
+ });
+ }
+
+ public virtual async Task StartAsync()
+ {
+ await AcquireConnectionLock().ConfigureAwait(false);
+ var reconnectCancelToken = new CancellationTokenSource();
+ _reconnectCancelToken = new CancellationTokenSource();
+ _task = Task.Run(async () =>
+ {
+ try
+ {
+ Random jitter = new Random();
+ int nextReconnectDelay = 1000;
+ while (!reconnectCancelToken.IsCancellationRequested)
+ {
+ try
+ {
+ await ConnectAsync(reconnectCancelToken).ConfigureAwait(false);
+ nextReconnectDelay = 1000; //Reset delay
+ await _connectionPromise.Task.ConfigureAwait(false);
+ }
+ catch (OperationCanceledException ex)
+ {
+ Cancel(); //In case this exception didn't come from another Error call
+ await DisconnectAsync(ex, !reconnectCancelToken.IsCancellationRequested).ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ Error(ex); //In case this exception didn't come from another Error call
+ if (!reconnectCancelToken.IsCancellationRequested)
+ {
+ await _logger.WarningAsync(ex).ConfigureAwait(false);
+ await DisconnectAsync(ex, true).ConfigureAwait(false);
+ }
+ else
+ {
+ await _logger.ErrorAsync(ex).ConfigureAwait(false);
+ await DisconnectAsync(ex, false).ConfigureAwait(false);
+ }
+ }
+
+ if (!reconnectCancelToken.IsCancellationRequested)
+ {
+ //Wait before reconnecting
+ await Task.Delay(nextReconnectDelay, reconnectCancelToken.Token).ConfigureAwait(false);
+ nextReconnectDelay = (nextReconnectDelay * 2) + jitter.Next(-250, 250);
+ if (nextReconnectDelay > 60000)
+ nextReconnectDelay = 60000;
+ }
+ }
+ }
+ finally { _stateLock.Release(); }
+ });
+ }
+ public virtual async Task StopAsync()
+ {
+ Cancel();
+ var task = _task;
+ if (task != null)
+ await task.ConfigureAwait(false);
+ }
+
+ private async Task ConnectAsync(CancellationTokenSource reconnectCancelToken)
+ {
+ _connectionCancelToken = new CancellationTokenSource();
+ _combinedCancelToken = CancellationTokenSource.CreateLinkedTokenSource(_connectionCancelToken.Token, reconnectCancelToken.Token);
+ CancelToken = _combinedCancelToken.Token;
+
+ _connectionPromise = new TaskCompletionSource();
+ State = ConnectionState.Connecting;
+ await _logger.InfoAsync("Connecting").ConfigureAwait(false);
+
+ try
+ {
+ var readyPromise = new TaskCompletionSource();
+ _readyPromise = readyPromise;
+
+ //Abort connection on timeout
+ var cancelToken = CancelToken;
+ var _ = Task.Run(async () =>
+ {
+ try
+ {
+ await Task.Delay(_connectionTimeout, cancelToken).ConfigureAwait(false);
+ readyPromise.TrySetException(new TimeoutException());
+ }
+ catch (OperationCanceledException) { }
+ });
+
+ await _onConnecting().ConfigureAwait(false);
+
+ await _logger.InfoAsync("Connected").ConfigureAwait(false);
+ State = ConnectionState.Connected;
+ await _logger.DebugAsync("Raising Event").ConfigureAwait(false);
+ await _connectedEvent.InvokeAsync().ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ Error(ex);
+ throw;
+ }
+ }
+ private async Task DisconnectAsync(Exception ex, bool isReconnecting)
+ {
+ if (State == ConnectionState.Disconnected) return;
+ State = ConnectionState.Disconnecting;
+ await _logger.InfoAsync("Disconnecting").ConfigureAwait(false);
+
+ await _onDisconnecting(ex).ConfigureAwait(false);
+
+ await _logger.InfoAsync("Disconnected").ConfigureAwait(false);
+ State = ConnectionState.Disconnected;
+ await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
+ }
+
+ public async Task CompleteAsync()
+ {
+ await _readyPromise.TrySetResultAsync(true).ConfigureAwait(false);
+ }
+ public async Task WaitAsync()
+ {
+ await _readyPromise.Task.ConfigureAwait(false);
+ }
+
+ public void Cancel()
+ {
+ _readyPromise?.TrySetCanceled();
+ _connectionPromise?.TrySetCanceled();
+ _reconnectCancelToken?.Cancel();
+ _connectionCancelToken?.Cancel();
+ }
+ public void Error(Exception ex)
+ {
+ _readyPromise.TrySetException(ex);
+ _connectionPromise.TrySetException(ex);
+ _connectionCancelToken?.Cancel();
+ }
+ public void CriticalError(Exception ex)
+ {
+ _reconnectCancelToken?.Cancel();
+ Error(ex);
+ }
+ private async Task AcquireConnectionLock()
+ {
+ while (true)
+ {
+ await StopAsync().ConfigureAwait(false);
+ if (await _stateLock.WaitAsync(0).ConfigureAwait(false))
+ break;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs
index 0727576bf..0ff1a4821 100644
--- a/src/Discord.Net.Rest/DiscordRestClient.cs
+++ b/src/Discord.Net.Rest/DiscordRestClient.cs
@@ -16,14 +16,19 @@ namespace Discord.Rest
private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config)
=> new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent);
+ internal override void Dispose(bool disposing)
+ {
+ if (disposing)
+ ApiClient.Dispose();
+ }
- protected override async Task OnLoginAsync(TokenType tokenType, string token)
+ internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
var user = await ApiClient.GetMyUserAsync(new RequestOptions { RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false);
ApiClient.CurrentUserId = user.Id;
base.CurrentUser = RestSelfUser.Create(this, user);
}
- protected override Task OnLogoutAsync()
+ internal override Task OnLogoutAsync()
{
_applicationInfo = null;
return Task.Delay(0);
diff --git a/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs b/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs
index 2a9ae21bf..d3c50a5ec 100644
--- a/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs
+++ b/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs
@@ -12,12 +12,12 @@ namespace Discord.Rpc
remove { _connectedEvent.Remove(value); }
}
private readonly AsyncEvent> _connectedEvent = new AsyncEvent>();
- public event Func Disconnected
+ public event Func Disconnected
{
add { _disconnectedEvent.Add(value); }
remove { _disconnectedEvent.Remove(value); }
}
- private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>();
+ private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>();
public event Func Ready
{
add { _readyEvent.Add(value); }
diff --git a/src/Discord.Net.Rpc/DiscordRpcClient.cs b/src/Discord.Net.Rpc/DiscordRpcClient.cs
index 845ba97c6..5235c98d4 100644
--- a/src/Discord.Net.Rpc/DiscordRpcClient.cs
+++ b/src/Discord.Net.Rpc/DiscordRpcClient.cs
@@ -8,28 +8,22 @@ using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
-using System.Threading;
using System.Threading.Tasks;
+using System.Threading;
namespace Discord.Rpc
{
public partial class DiscordRpcClient : BaseDiscordClient, IDiscordClient
{
- private readonly Logger _rpcLogger;
private readonly JsonSerializer _serializer;
-
- private TaskCompletionSource _connectTask;
- private CancellationTokenSource _cancelToken, _reconnectCancelToken;
- private Task _reconnectTask;
- private bool _canReconnect;
+ private readonly ConnectionManager _connection;
+ private readonly Logger _rpcLogger;
+ private readonly SemaphoreSlim _stateLock, _authorizeLock;
public ConnectionState ConnectionState { get; private set; }
public IReadOnlyCollection Scopes { get; private set; }
public DateTimeOffset TokenExpiresAt { get; private set; }
- //From DiscordRpcConfig
- internal int ConnectionTimeout { get; private set; }
-
internal new API.DiscordRpcApiClient ApiClient => base.ApiClient as API.DiscordRpcApiClient;
public new RestSelfUser CurrentUser { get { return base.CurrentUser as RestSelfUser; } private set { base.CurrentUser = value; } }
public RestApplication ApplicationInfo { get; private set; }
@@ -41,8 +35,11 @@ namespace Discord.Rpc
public DiscordRpcClient(string clientId, string origin, DiscordRpcConfig config)
: base(config, CreateApiClient(clientId, origin, config))
{
- ConnectionTimeout = config.ConnectionTimeout;
+ _stateLock = new SemaphoreSlim(1, 1);
+ _authorizeLock = new SemaphoreSlim(1, 1);
_rpcLogger = LogManager.CreateLogger("RPC");
+ _connection = new ConnectionManager(_stateLock, _rpcLogger, config.ConnectionTimeout,
+ OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
@@ -53,177 +50,52 @@ namespace Discord.Rpc
ApiClient.SentRpcMessage += async opCode => await _rpcLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedRpcEvent += ProcessMessageAsync;
- ApiClient.Disconnected += async ex =>
- {
- if (ex != null)
- {
- await _rpcLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
- await StartReconnectAsync(ex).ConfigureAwait(false);
- }
- else
- await _rpcLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
- };
}
private static API.DiscordRpcApiClient CreateApiClient(string clientId, string origin, DiscordRpcConfig config)
=> new API.DiscordRpcApiClient(clientId, DiscordRestConfig.UserAgent, origin, config.RestClientProvider, config.WebSocketProvider);
-
- ///
- public async Task ConnectAsync()
+ internal override void Dispose(bool disposing)
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
+ if (disposing)
{
- await ConnectInternalAsync(false).ConfigureAwait(false);
+ StopAsync().GetAwaiter().GetResult();
+ ApiClient.Dispose();
}
- finally { _connectionLock.Release(); }
}
- private async Task ConnectInternalAsync(bool isReconnecting)
- {
- if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
- _reconnectCancelToken.Cancel();
- var state = ConnectionState;
- if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
- await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
+ public Task StartAsync() => _connection.StartAsync();
+ public Task StopAsync() => _connection.StopAsync();
- ConnectionState = ConnectionState.Connecting;
- await _rpcLogger.InfoAsync("Connecting").ConfigureAwait(false);
- try
- {
- var connectTask = new TaskCompletionSource();
- _connectTask = connectTask;
- _cancelToken = new CancellationTokenSource();
-
- //Abort connection on timeout
- var _ = Task.Run(async () =>
- {
- await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
- connectTask.TrySetException(new TimeoutException());
- });
-
- await ApiClient.ConnectAsync().ConfigureAwait(false);
- await _connectedEvent.InvokeAsync().ConfigureAwait(false);
-
- await _connectTask.Task.ConfigureAwait(false);
- if (!isReconnecting)
- _canReconnect = true;
- ConnectionState = ConnectionState.Connected;
- await _rpcLogger.InfoAsync("Connected").ConfigureAwait(false);
- }
- catch (Exception)
- {
- await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
- throw;
- }
- }
- ///
- public async Task DisconnectAsync()
+ private async Task OnConnectingAsync()
{
- if (_connectTask?.TrySetCanceled() ?? false) return;
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await DisconnectInternalAsync(null, false).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
+ await _rpcLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
+ await ApiClient.ConnectAsync().ConfigureAwait(false);
+
+ await _connection.WaitAsync().ConfigureAwait(false);
}
- private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting)
+ private async Task OnDisconnectingAsync(Exception ex)
{
- if (!isReconnecting)
- {
- _canReconnect = false;
-
- if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
- _reconnectCancelToken.Cancel();
- }
-
- if (ConnectionState == ConnectionState.Disconnected) return;
- ConnectionState = ConnectionState.Disconnecting;
- await _rpcLogger.InfoAsync("Disconnecting").ConfigureAwait(false);
-
- await _rpcLogger.DebugAsync("Disconnecting - CancelToken").ConfigureAwait(false);
- //Signal tasks to complete
- try { _cancelToken.Cancel(); } catch { }
-
- await _rpcLogger.DebugAsync("Disconnecting - ApiClient").ConfigureAwait(false);
- //Disconnect from server
+ await _rpcLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync().ConfigureAwait(false);
-
- ConnectionState = ConnectionState.Disconnected;
- await _rpcLogger.InfoAsync("Disconnected").ConfigureAwait(false);
-
- await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
}
- private async Task StartReconnectAsync(Exception ex)
+ public async Task AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
+ await _authorizeLock.WaitAsync().ConfigureAwait(false);
try
{
- if (!_canReconnect || _reconnectTask != null) return;
- _reconnectCancelToken = new CancellationTokenSource();
- _reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken)
- {
- if (ex == null)
- {
- if (_connectTask?.TrySetCanceled() ?? false) return;
+ await _connection.StartAsync().ConfigureAwait(false);
+ await _connection.WaitAsync().ConfigureAwait(false);
+ var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false);
+ await _connection.StopAsync().ConfigureAwait(false);
+ return result.Code;
}
- else
+ finally
{
- if (_connectTask?.TrySetException(ex) ?? false) return;
- }
-
- try
- {
- Random jitter = new Random();
- int nextReconnectDelay = 1000;
- while (true)
- {
- await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false);
- nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250);
- if (nextReconnectDelay > 60000)
- nextReconnectDelay = 60000;
-
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- if (cancelToken.IsCancellationRequested) return;
- await ConnectInternalAsync(true).ConfigureAwait(false);
- _reconnectTask = null;
- return;
- }
- catch (Exception ex2)
- {
- await _rpcLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- }
- catch (OperationCanceledException)
- {
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await _rpcLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false);
- _reconnectTask = null;
- }
- finally { _connectionLock.Release(); }
+ _authorizeLock.Release();
}
}
- public async Task AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
- {
- await ConnectAsync().ConfigureAwait(false);
- var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false);
- await DisconnectAsync().ConfigureAwait(false);
- return result.Code;
- }
-
public async Task SubscribeGlobal(RpcGlobalEvent evnt, RequestOptions options = null)
{
await ApiClient.SendGlobalSubscribeAsync(GetEventName(evnt), options).ConfigureAwait(false);
@@ -439,8 +311,8 @@ namespace Discord.Rpc
ApplicationInfo = RestApplication.Create(this, response.Application);
Scopes = response.Scopes;
TokenExpiresAt = response.Expires;
-
- var __ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
+
+ var __ = _connection.CompleteAsync();
await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false);
}
catch (Exception ex)
@@ -452,7 +324,7 @@ namespace Discord.Rpc
}
else
{
- var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
+ var _ = _connection.CompleteAsync();
await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false);
}
}
@@ -592,6 +464,13 @@ namespace Discord.Rpc
}
//IDiscordClient
+ ConnectionState IDiscordClient.ConnectionState => _connection.State;
+
Task IDiscordClient.GetApplicationInfoAsync() => Task.FromResult(ApplicationInfo);
+
+ async Task IDiscordClient.StartAsync()
+ => await StartAsync().ConfigureAwait(false);
+ async Task IDiscordClient.StopAsync()
+ => await StopAsync().ConfigureAwait(false);
}
}
diff --git a/src/Discord.Net.WebSocket/Audio/AudioClient.cs b/src/Discord.Net.WebSocket/Audio/AudioClient.cs
index 5bedf1786..5404227f2 100644
--- a/src/Discord.Net.WebSocket/Audio/AudioClient.cs
+++ b/src/Discord.Net.WebSocket/Audio/AudioClient.cs
@@ -5,6 +5,7 @@ using Discord.WebSocket;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
+using System.Collections.Concurrent;
using System.Linq;
using System.Text;
using System.Threading;
@@ -12,6 +13,7 @@ using System.Threading.Tasks;
namespace Discord.Audio
{
+ //TODO: Add audio reconnecting
internal class AudioClient : IAudioClient, IDisposable
{
public event Func Connected
@@ -34,34 +36,37 @@ namespace Discord.Audio
private readonly AsyncEvent> _latencyUpdatedEvent = new AsyncEvent>();
private readonly Logger _audioLogger;
- internal readonly SemaphoreSlim _connectionLock;
private readonly JsonSerializer _serializer;
+ private readonly ConnectionManager _connection;
+ private readonly SemaphoreSlim _stateLock;
+ private readonly ConcurrentQueue _heartbeatTimes;
- private TaskCompletionSource _connectTask;
- private CancellationTokenSource _cancelTokenSource;
private Task _heartbeatTask;
- private long _heartbeatTime;
- private string _url;
+ private long _lastMessageTime;
+ private string _url, _sessionId, _token;
+ private ulong _userId;
private uint _ssrc;
private byte[] _secretKey;
- private bool _isDisposed;
public SocketGuild Guild { get; }
public DiscordVoiceAPIClient ApiClient { get; private set; }
- public ConnectionState ConnectionState { get; private set; }
public int Latency { get; private set; }
private DiscordSocketClient Discord => Guild.Discord;
+ public ConnectionState ConnectionState => _connection.State;
/// Creates a new REST/WebSocket discord client.
internal AudioClient(SocketGuild guild, int id)
{
Guild = guild;
- _audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}");
-
- _connectionLock = new SemaphoreSlim(1, 1);
+ _stateLock = new SemaphoreSlim(1, 1);
+ _connection = new ConnectionManager(_stateLock, _audioLogger, 30000,
+ OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
+ _heartbeatTimes = new ConcurrentQueue();
+ _audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}");
+
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
{
@@ -76,83 +81,28 @@ namespace Discord.Audio
//ApiClient.SentData += async bytes => await _audioLogger.DebugAsync($"Sent {bytes} Bytes").ConfigureAwait(false);
ApiClient.ReceivedEvent += ProcessMessageAsync;
ApiClient.ReceivedPacket += ProcessPacketAsync;
- ApiClient.Disconnected += async ex =>
- {
- if (ex != null)
- await _audioLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
- else
- await _audioLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
- };
LatencyUpdated += async (old, val) => await _audioLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false);
}
- ///
- internal async Task ConnectAsync(string url, ulong userId, string sessionId, string token)
+ internal async Task StartAsync(string url, ulong userId, string sessionId, string token)
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await ConnectInternalAsync(url, userId, sessionId, token).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
+ _url = url;
+ _userId = userId;
+ _sessionId = sessionId;
+ _token = token;
+ await _connection.StartAsync().ConfigureAwait(false);
}
- private async Task ConnectInternalAsync(string url, ulong userId, string sessionId, string token)
- {
- var state = ConnectionState;
- if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
- await DisconnectInternalAsync(null).ConfigureAwait(false);
-
- ConnectionState = ConnectionState.Connecting;
- await _audioLogger.InfoAsync("Connecting").ConfigureAwait(false);
- try
- {
- _url = url;
- _connectTask = new TaskCompletionSource();
- _cancelTokenSource = new CancellationTokenSource();
-
- await ApiClient.ConnectAsync("wss://" + url).ConfigureAwait(false);
- await ApiClient.SendIdentityAsync(userId, sessionId, token).ConfigureAwait(false);
- await _connectTask.Task.ConfigureAwait(false);
+ public async Task StopAsync()
+ => await _connection.StopAsync().ConfigureAwait(false);
- await _connectedEvent.InvokeAsync().ConfigureAwait(false);
- ConnectionState = ConnectionState.Connected;
- await _audioLogger.InfoAsync("Connected").ConfigureAwait(false);
- }
- catch (Exception)
- {
- await DisconnectInternalAsync(null).ConfigureAwait(false);
- throw;
- }
- }
- ///
- public async Task DisconnectAsync()
+ private async Task OnConnectingAsync()
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await DisconnectInternalAsync(null).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
+ await ApiClient.ConnectAsync("wss://" + _url).ConfigureAwait(false);
+ await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);
}
- private async Task DisconnectAsync(Exception ex)
+ private async Task OnDisconnectingAsync(Exception ex)
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await DisconnectInternalAsync(ex).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task DisconnectInternalAsync(Exception ex)
- {
- if (ConnectionState == ConnectionState.Disconnected) return;
- ConnectionState = ConnectionState.Disconnecting;
- await _audioLogger.InfoAsync("Disconnecting").ConfigureAwait(false);
-
- //Signal tasks to complete
- try { _cancelTokenSource.Cancel(); } catch { }
-
//Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false);
@@ -162,17 +112,17 @@ namespace Discord.Audio
await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null;
- ConnectionState = ConnectionState.Disconnected;
- await _audioLogger.InfoAsync("Disconnected").ConfigureAwait(false);
- await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
-
await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);
+
+ long time;
+ while (_heartbeatTimes.TryDequeue(out time)) { }
+ _lastMessageTime = 0;
}
public AudioOutStream CreateOpusStream(int samplesPerFrame, int bufferMillis)
{
CheckSamplesPerFrame(samplesPerFrame);
- var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token);
+ var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken);
return new RTPWriteStream(target, _secretKey, samplesPerFrame, _ssrc);
}
public AudioOutStream CreateDirectOpusStream(int samplesPerFrame)
@@ -184,7 +134,7 @@ namespace Discord.Audio
public AudioOutStream CreatePCMStream(int samplesPerFrame, int channels, int? bitrate, int bufferMillis)
{
CheckSamplesPerFrame(samplesPerFrame);
- var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token);
+ var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken);
return new OpusEncodeStream(target, _secretKey, channels, samplesPerFrame, _ssrc, bitrate);
}
public AudioOutStream CreateDirectPCMStream(int samplesPerFrame, int channels, int? bitrate)
@@ -202,6 +152,8 @@ namespace Discord.Audio
private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
{
+ _lastMessageTime = Environment.TickCount;
+
try
{
switch (opCode)
@@ -216,8 +168,7 @@ namespace Discord.Audio
if (!data.Modes.Contains(DiscordVoiceAPIClient.Mode))
throw new InvalidOperationException($"Discord does not support {DiscordVoiceAPIClient.Mode}");
- _heartbeatTime = 0;
- _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelTokenSource.Token);
+ _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken);
ApiClient.SetUdpEndpoint(_url, data.Port);
await ApiClient.SendDiscoveryAsync(_ssrc).ConfigureAwait(false);
@@ -234,19 +185,17 @@ namespace Discord.Audio
_secretKey = data.SecretKey;
await ApiClient.SendSetSpeaking(true).ConfigureAwait(false);
- var _ = _connectTask.TrySetResultAsync(true);
+ var _ = _connection.CompleteAsync();
}
break;
case VoiceOpCode.HeartbeatAck:
{
await _audioLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false);
- var heartbeatTime = _heartbeatTime;
- if (heartbeatTime != 0)
+ long time;
+ if (_heartbeatTimes.TryDequeue(out time))
{
- int latency = (int)(Environment.TickCount - _heartbeatTime);
- _heartbeatTime = 0;
-
+ int latency = (int)(Environment.TickCount - time);
int before = Latency;
Latency = latency;
@@ -267,7 +216,7 @@ namespace Discord.Audio
}
private async Task ProcessPacketAsync(byte[] packet)
{
- if (!_connectTask.Task.IsCompleted)
+ if (!_connection.IsCompleted)
{
if (packet.Length == 70)
{
@@ -291,33 +240,50 @@ namespace Discord.Audio
//Clean this up when Discord's session patch is live
try
{
+ await _audioLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
while (!cancelToken.IsCancellationRequested)
{
+ var now = Environment.TickCount;
+
+ //Did server respond to our last heartbeat, or are we still receiving messages (long load?)
+ if (_heartbeatTimes.Count != 0 && (now - _lastMessageTime) > intervalMillis &&
+ ConnectionState == ConnectionState.Connected)
+ {
+ _connection.Error(new Exception("Server missed last heartbeat"));
+ return;
+ }
+ _heartbeatTimes.Enqueue(now);
+
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
- if (_heartbeatTime != 0) //Server never responded to our last heartbeat
+ try
{
- if (ConnectionState == ConnectionState.Connected)
- {
- await _audioLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
- await DisconnectInternalAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false);
- return;
- }
+ await ApiClient.SendHeartbeatAsync().ConfigureAwait(false);
}
- else
- _heartbeatTime = Environment.TickCount;
- await ApiClient.SendHeartbeatAsync().ConfigureAwait(false);
+ catch (Exception ex)
+ {
+ await _audioLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
+ }
+
+ await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
}
+ await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
+ }
+ catch (OperationCanceledException)
+ {
+ await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ await _audioLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
- catch (OperationCanceledException) { }
}
internal void Dispose(bool disposing)
{
- if (disposing && !_isDisposed)
+ if (disposing)
{
- _isDisposed = true;
- DisconnectInternalAsync(null).GetAwaiter().GetResult();
+ StopAsync().GetAwaiter().GetResult();
ApiClient.Dispose();
}
}
diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index e897a0b40..cadbda6d1 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -72,7 +72,7 @@ namespace Discord.WebSocket
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent);
- protected override async Task OnLoginAsync(TokenType tokenType, string token)
+ internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
if (_automaticShards)
{
@@ -95,7 +95,7 @@ namespace Discord.WebSocket
for (int i = 0; i < _shards.Length; i++)
await _shards[i].LoginAsync(tokenType, token, false);
}
- protected override async Task OnLogoutAsync()
+ internal override async Task OnLogoutAsync()
{
//Assume threadsafe: already in a connection lock
for (int i = 0; i < _shards.Length; i++)
@@ -112,42 +112,14 @@ namespace Discord.WebSocket
}
///
- public async Task ConnectAsync()
+ public async Task StartAsync()
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await ConnectInternalAsync().ConfigureAwait(false);
- }
- catch
- {
- await DisconnectInternalAsync().ConfigureAwait(false);
- throw;
- }
- finally { _connectionLock.Release(); }
- }
- private async Task ConnectInternalAsync()
- {
- await Task.WhenAll(
- _shards.Select(x => x.ConnectAsync())
- ).ConfigureAwait(false);
-
- CurrentUser = _shards[0].CurrentUser;
+ await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false);
}
///
- public async Task DisconnectAsync()
+ public async Task StopAsync()
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await DisconnectInternalAsync().ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task DisconnectInternalAsync()
- {
- for (int i = 0; i < _shards.Length; i++)
- await _shards[i].DisconnectAsync();
+ await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false);
}
public DiscordSocketClient GetShard(int id)
@@ -334,9 +306,6 @@ namespace Discord.WebSocket
}
//IDiscordClient
- Task IDiscordClient.ConnectAsync()
- => ConnectAsync();
-
async Task IDiscordClient.GetApplicationInfoAsync()
=> await GetApplicationInfoAsync().ConfigureAwait(false);
diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
index cbefd795c..7d680eaf2 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
@@ -28,6 +28,7 @@ namespace Discord.API
private CancellationTokenSource _connectCancelToken;
private string _gatewayUrl;
+ private bool _isExplicitUrl;
internal IWebSocketClient WebSocketClient { get; }
@@ -38,6 +39,8 @@ namespace Discord.API
: base(restClientProvider, userAgent, defaultRetryMode, serializer)
{
_gatewayUrl = url;
+ if (url != null)
+ _isExplicitUrl = true;
WebSocketClient = webSocketProvider();
//WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+)
WebSocketClient.BinaryMessage += async (data, index, count) =>
@@ -52,7 +55,8 @@ namespace Discord.API
using (var jsonReader = new JsonTextReader(reader))
{
var msg = _serializer.Deserialize(jsonReader);
- await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
+ if (msg != null)
+ await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
}
}
};
@@ -62,7 +66,8 @@ namespace Discord.API
using (var jsonReader = new JsonTextReader(reader))
{
var msg = _serializer.Deserialize(jsonReader);
- await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
+ if (msg != null)
+ await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false);
}
};
WebSocketClient.Closed += async ex =>
@@ -107,7 +112,7 @@ namespace Discord.API
if (WebSocketClient != null)
WebSocketClient.SetCancelToken(_connectCancelToken.Token);
- if (_gatewayUrl == null)
+ if (!_isExplicitUrl)
{
var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false);
_gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}";
@@ -118,7 +123,8 @@ namespace Discord.API
}
catch
{
- _gatewayUrl = null; //Uncache in case the gateway url changed
+ if (!_isExplicitUrl)
+ _gatewayUrl = null; //Uncache in case the gateway url changed
await DisconnectInternalAsync().ConfigureAwait(false);
throw;
}
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index c0608a868..092225376 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -17,29 +17,27 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using GameModel = Discord.API.Game;
-using Discord.Net;
namespace Discord.WebSocket
{
public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient
{
private readonly ConcurrentQueue _largeGuilds;
- private readonly Logger _gatewayLogger;
private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock;
private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue _heartbeatTimes;
+ private readonly ConnectionManager _connection;
+ private readonly Logger _gatewayLogger;
+ private readonly SemaphoreSlim _stateLock;
private string _sessionId;
private int _lastSeq;
private ImmutableDictionary _voiceRegions;
- private TaskCompletionSource _connectTask;
- private CancellationTokenSource _cancelToken, _reconnectCancelToken;
- private Task _heartbeatTask, _guildDownloadTask, _reconnectTask;
+ private Task _heartbeatTask, _guildDownloadTask;
private int _unavailableGuilds;
private long _lastGuildAvailableTime, _lastMessageTime;
private int _nextAudioId;
- private bool _canReconnect;
private DateTimeOffset? _statusSince;
private RestApplication _applicationInfo;
private ConcurrentHashSet _downloadUsersFor;
@@ -59,7 +57,6 @@ namespace Discord.WebSocket
internal int LargeThreshold { get; private set; }
internal AudioMode AudioMode { get; private set; }
internal ClientState State { get; private set; }
- internal int ConnectionTimeout { get; private set; }
internal UdpSocketProvider UdpSocketProvider { get; private set; }
internal WebSocketProvider WebSocketProvider { get; private set; }
internal bool AlwaysDownloadUsers { get; private set; }
@@ -90,35 +87,28 @@ namespace Discord.WebSocket
UdpSocketProvider = config.UdpSocketProvider;
WebSocketProvider = config.WebSocketProvider;
AlwaysDownloadUsers = config.AlwaysDownloadUsers;
- ConnectionTimeout = config.ConnectionTimeout;
State = new ClientState(0, 0);
_downloadUsersFor = new ConcurrentHashSet();
_heartbeatTimes = new ConcurrentQueue();
+
+ _stateLock = new SemaphoreSlim(1, 1);
+ _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : $"Shard #{ShardId}");
+ _connection = new ConnectionManager(_stateLock, _gatewayLogger, config.ConnectionTimeout,
+ OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_nextAudioId = 1;
- _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId);
_connectionGroupLock = groupLock;
_parentClient = parentClient;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
_serializer.Error += (s, e) =>
{
- _gatewayLogger.WarningAsync(e.ErrorContext.Error).GetAwaiter().GetResult();
+ _gatewayLogger.WarningAsync("Serializer Error", e.ErrorContext.Error).GetAwaiter().GetResult();
e.ErrorContext.Handled = true;
};
ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedGatewayEvent += ProcessMessageAsync;
- ApiClient.Disconnected += async ex =>
- {
- if (ex != null)
- {
- await _gatewayLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false);
- await StartReconnectAsync(ex).ConfigureAwait(false);
- }
- else
- await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false);
- };
LeftGuild += async g => await _gatewayLogger.InfoAsync($"Left {g.Name}").ConfigureAwait(false);
JoinedGuild += async g => await _gatewayLogger.InfoAsync($"Joined {g.Name}").ConfigureAwait(false);
@@ -143,8 +133,16 @@ namespace Discord.WebSocket
}
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config)
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost);
+ internal override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ StopAsync().GetAwaiter().GetResult();
+ ApiClient.Dispose();
+ }
+ }
- protected override async Task OnLoginAsync(TokenType tokenType, string token)
+ internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
if (_parentClient == null)
{
@@ -154,92 +152,49 @@ namespace Discord.WebSocket
else
_voiceRegions = _parentClient._voiceRegions;
}
- protected override async Task OnLogoutAsync()
+ internal override async Task OnLogoutAsync()
{
- if (ConnectionState != ConnectionState.Disconnected)
- await DisconnectInternalAsync(null, false).ConfigureAwait(false);
-
+ await StopAsync().ConfigureAwait(false);
_applicationInfo = null;
_voiceRegions = ImmutableDictionary.Create();
_downloadUsersFor.Clear();
}
+
+ public async Task StartAsync()
+ => await _connection.StartAsync().ConfigureAwait(false);
+ public async Task StopAsync()
+ => await _connection.StopAsync().ConfigureAwait(false);
- ///
- public async Task ConnectAsync()
+ private async Task OnConnectingAsync()
{
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await ConnectInternalAsync(false).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task ConnectInternalAsync(bool isReconnecting)
- {
- if (LoginState != LoginState.LoggedIn)
- throw new InvalidOperationException("Client is not logged in.");
-
- if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
- _reconnectCancelToken.Cancel();
-
- var state = ConnectionState;
- if (state == ConnectionState.Connecting || state == ConnectionState.Connected)
- await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
-
if (_connectionGroupLock != null)
- await _connectionGroupLock.WaitAsync().ConfigureAwait(false);
+ await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false);
try
{
- _canReconnect = true;
- ConnectionState = ConnectionState.Connecting;
- await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false);
-
- try
- {
- var connectTask = new TaskCompletionSource();
- _connectTask = connectTask;
- _cancelToken = new CancellationTokenSource();
-
- //Abort connection on timeout
- var _ = Task.Run(async () =>
- {
- await Task.Delay(ConnectionTimeout).ConfigureAwait(false);
- connectTask.TrySetException(new TimeoutException());
- });
-
- await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
- await ApiClient.ConnectAsync().ConfigureAwait(false);
- await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
- await _connectedEvent.InvokeAsync().ConfigureAwait(false);
-
- if (_sessionId != null)
- {
- await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
- await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
- }
- else
- {
- await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
- await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
- }
-
- await _connectTask.Task.ConfigureAwait(false);
+ await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
+ await ApiClient.ConnectAsync().ConfigureAwait(false);
+ await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
+ await _connectedEvent.InvokeAsync().ConfigureAwait(false);
- await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
- await SendStatusAsync().ConfigureAwait(false);
-
- await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false);
- ConnectionState = ConnectionState.Connected;
- await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false);
-
- await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
- .Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
+ if (_sessionId != null)
+ {
+ await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false);
+ await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false);
}
- catch (Exception)
+ else
{
- await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false);
- throw;
+ await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false);
+ await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false);
}
+
+ //Wait for READY
+ await _connection.WaitAsync().ConfigureAwait(false);
+
+ await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false);
+ await SendStatusAsync().ConfigureAwait(false);
+
+ await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x))
+ .Where(x => x != null).ToImmutableArray()).ConfigureAwait(false);
}
finally
{
@@ -250,41 +205,11 @@ namespace Discord.WebSocket
}
}
}
- ///
- public async Task DisconnectAsync()
- {
- if (_connectTask?.TrySetCanceled() ?? false) return;
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await DisconnectInternalAsync(null, false).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting)
+ private async Task OnDisconnectingAsync(Exception ex)
{
- if (!isReconnecting)
- {
- _canReconnect = false;
- _sessionId = null;
- _lastSeq = 0;
-
- if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
- _reconnectCancelToken.Cancel();
- }
-
ulong guildId;
- if (ConnectionState == ConnectionState.Disconnected) return;
- ConnectionState = ConnectionState.Disconnecting;
- await _gatewayLogger.InfoAsync("Disconnecting").ConfigureAwait(false);
-
- await _gatewayLogger.DebugAsync("Cancelling current tasks").ConfigureAwait(false);
- //Signal tasks to complete
- try { _cancelToken.Cancel(); } catch { }
-
await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
- //Disconnect from server
await ApiClient.DisconnectAsync().ConfigureAwait(false);
//Wait for tasks to complete
@@ -294,8 +219,8 @@ namespace Discord.WebSocket
await heartbeatTask.ConfigureAwait(false);
_heartbeatTask = null;
- long times;
- while (_heartbeatTimes.TryDequeue(out times)) { }
+ long time;
+ while (_heartbeatTimes.TryDequeue(out time)) { }
_lastMessageTime = 0;
await _gatewayLogger.DebugAsync("Waiting for guild downloader").ConfigureAwait(false);
@@ -315,70 +240,6 @@ namespace Discord.WebSocket
if (guild._available)
await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false);
}
-
- ConnectionState = ConnectionState.Disconnected;
- await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false);
-
- await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false);
- }
-
- private async Task StartReconnectAsync(Exception ex)
- {
- if ((ex as WebSocketClosedException)?.CloseCode == 4004) //Bad Token
- {
- _canReconnect = false;
- _connectTask?.TrySetException(ex);
- await LogoutAsync().ConfigureAwait(false);
- return;
- }
-
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- if (!_canReconnect || _reconnectTask != null) return;
- _reconnectCancelToken = new CancellationTokenSource();
- _reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token);
- }
- finally { _connectionLock.Release(); }
- }
- private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken)
- {
- try
- {
- Random jitter = new Random();
- int nextReconnectDelay = 1000;
- while (true)
- {
- await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false);
- nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250);
- if (nextReconnectDelay > 60000)
- nextReconnectDelay = 60000;
-
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- if (cancelToken.IsCancellationRequested) return;
- await ConnectInternalAsync(true).ConfigureAwait(false);
- _reconnectTask = null;
- return;
- }
- catch (Exception ex2)
- {
- await _gatewayLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false);
- }
- finally { _connectionLock.Release(); }
- }
- }
- catch (OperationCanceledException)
- {
- await _connectionLock.WaitAsync().ConfigureAwait(false);
- try
- {
- await _gatewayLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false);
- _reconnectTask = null;
- }
- finally { _connectionLock.Release(); }
- }
}
///
@@ -555,7 +416,7 @@ namespace Discord.WebSocket
await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false);
var data = (payload as JToken).ToObject(_serializer);
- _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelToken.Token, _gatewayLogger);
+ _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken);
}
break;
case GatewayOpCode.Heartbeat:
@@ -593,9 +454,7 @@ namespace Discord.WebSocket
case GatewayOpCode.Reconnect:
{
await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false);
- await _gatewayLogger.WarningAsync("Server requested a reconnect").ConfigureAwait(false);
-
- await StartReconnectAsync(new Exception("Server requested a reconnect")).ConfigureAwait(false);
+ _connection.Error(new Exception("Server requested a reconnect"));
}
break;
case GatewayOpCode.Dispatch:
@@ -633,8 +492,7 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
- _canReconnect = false;
- _connectTask.TrySetException(new Exception("Processing READY failed", ex));
+ _connection.CriticalError(new Exception("Processing READY failed", ex));
return;
}
@@ -642,11 +500,11 @@ namespace Discord.WebSocket
await SyncGuildsAsync().ConfigureAwait(false);
_lastGuildAvailableTime = Environment.TickCount;
- _guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token, _gatewayLogger);
+ _guildDownloadTask = WaitForGuildsAsync(_connection.CancelToken, _gatewayLogger);
await _readyEvent.InvokeAsync().ConfigureAwait(false);
- var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
+ var _ = _connection.CompleteAsync();
await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false);
}
break;
@@ -654,7 +512,7 @@ namespace Discord.WebSocket
{
await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false);
- var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete
+ var _ = _connection.CompleteAsync();
//Notify the client that these guilds are available again
foreach (var guild in State.Guilds)
@@ -1356,7 +1214,6 @@ namespace Discord.WebSocket
SocketUserMessage cachedMsg = channel.GetCachedMessage(data.MessageId) as SocketUserMessage;
var user = await channel.GetUserAsync(data.UserId, CacheMode.CacheOnly);
SocketReaction reaction = SocketReaction.Create(data, channel, cachedMsg, Optional.Create(user));
-
if (cachedMsg != null)
{
cachedMsg.AddReaction(reaction);
@@ -1691,11 +1548,11 @@ namespace Discord.WebSocket
}
}
- private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken, Logger logger)
+ private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken)
{
try
{
- await logger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
+ await _gatewayLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false);
while (!cancelToken.IsCancellationRequested)
{
var now = Environment.TickCount;
@@ -1705,8 +1562,7 @@ namespace Discord.WebSocket
{
if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? true))
{
- await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false);
- await StartReconnectAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false);
+ _connection.Error(new Exception("Server missed last heartbeat"));
return;
}
}
@@ -1718,20 +1574,20 @@ namespace Discord.WebSocket
}
catch (Exception ex)
{
- await logger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
+ await _gatewayLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false);
}
- await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
+ await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (OperationCanceledException)
{
- await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
+ await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false);
}
catch (Exception ex)
{
- await logger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
+ await _gatewayLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false);
}
}
public async Task WaitForGuildsAsync()
@@ -1805,8 +1661,7 @@ namespace Discord.WebSocket
}
//IDiscordClient
- Task IDiscordClient.ConnectAsync()
- => ConnectAsync();
+ ConnectionState IDiscordClient.ConnectionState => _connection.State;
async Task IDiscordClient.GetApplicationInfoAsync()
=> await GetApplicationInfoAsync().ConfigureAwait(false);
@@ -1842,5 +1697,10 @@ namespace Discord.WebSocket
=> Task.FromResult>(VoiceRegions);
Task IDiscordClient.GetVoiceRegionAsync(string id)
=> Task.FromResult(GetVoiceRegion(id));
+
+ async Task IDiscordClient.StartAsync()
+ => await StartAsync().ConfigureAwait(false);
+ async Task IDiscordClient.StopAsync()
+ => await StopAsync().ConfigureAwait(false);
}
}
diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
index 78a637d0e..f42744c79 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
@@ -9,7 +9,7 @@ namespace Discord.WebSocket
{
public const string GatewayEncoding = "json";
- /// Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint.
+ /// Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint.
public string GatewayHost { get; set; } = null;
/// Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting.
diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
index 22a4c2a71..4e16985a7 100644
--- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
+++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
@@ -501,7 +501,7 @@ namespace Discord.WebSocket
_audioConnectPromise?.TrySetCanceledAsync(); //Cancel any previous audio connection
_audioConnectPromise = null;
if (_audioClient != null)
- await _audioClient.DisconnectAsync().ConfigureAwait(false);
+ await _audioClient.StopAsync().ConfigureAwait(false);
_audioClient = null;
}
internal async Task FinishConnectAudio(int id, string url, string token)
@@ -517,7 +517,6 @@ namespace Discord.WebSocket
var promise = _audioConnectPromise;
audioClient.Disconnected += async ex =>
{
- //If the initial connection hasn't been made yet, reconnecting will lead to deadlocks
if (!promise.Task.IsCompleted)
{
try { audioClient.Dispose(); } catch { }
@@ -528,41 +527,15 @@ namespace Discord.WebSocket
await promise.TrySetCanceledAsync();
return;
}
-
- //TODO: Implement reconnect
- /*await _audioLock.WaitAsync().ConfigureAwait(false);
- try
- {
- if (AudioClient == audioClient) //Only reconnect if we're still assigned as this guild's audio client
- {
- if (ex != null)
- {
- //Reconnect if we still have channel info.
- //TODO: Is this threadsafe? Could channel data be deleted before we access it?
- var voiceState2 = GetVoiceState(Discord.CurrentUser.Id);
- if (voiceState2.HasValue)
- {
- var voiceChannelId = voiceState2.Value.VoiceChannel?.Id;
- if (voiceChannelId != null)
- {
- await Discord.ApiClient.SendVoiceStateUpdateAsync(Id, voiceChannelId, voiceState2.Value.IsSelfDeafened, voiceState2.Value.IsSelfMuted);
- return;
- }
- }
- }
- try { audioClient.Dispose(); } catch { }
- AudioClient = null;
- }
- }
- finally
- {
- _audioLock.Release();
- }*/
};
_audioClient = audioClient;
}
- await _audioClient.ConnectAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
- await _audioConnectPromise.TrySetResultAsync(_audioClient).ConfigureAwait(false);
+ _audioClient.Connected += () =>
+ {
+ var _ = _audioConnectPromise.TrySetResultAsync(_audioClient);
+ return Task.Delay(0);
+ };
+ await _audioClient.StartAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{