diff --git a/src/Discord.Net.Core/Audio/IAudioClient.cs b/src/Discord.Net.Core/Audio/IAudioClient.cs index 472ad32f1..4a6ae2e27 100644 --- a/src/Discord.Net.Core/Audio/IAudioClient.cs +++ b/src/Discord.Net.Core/Audio/IAudioClient.cs @@ -14,7 +14,7 @@ namespace Discord.Audio /// Gets the estimated round-trip latency, in milliseconds, to the gateway server. int Latency { get; } - Task DisconnectAsync(); + Task StopAsync(); /// /// Creates a new outgoing stream accepting Opus-encoded data. diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs index b9a08d32d..1c5ec41c1 100644 --- a/src/Discord.Net.Core/IDiscordClient.cs +++ b/src/Discord.Net.Core/IDiscordClient.cs @@ -10,8 +10,8 @@ namespace Discord ConnectionState ConnectionState { get; } ISelfUser CurrentUser { get; } - Task ConnectAsync(); - Task DisconnectAsync(); + Task StartAsync(); + Task StopAsync(); Task GetApplicationInfoAsync(); diff --git a/src/Discord.Net.Core/Logging/LogMessage.cs b/src/Discord.Net.Core/Logging/LogMessage.cs index 9c3dfcfea..d1b3782be 100644 --- a/src/Discord.Net.Core/Logging/LogMessage.cs +++ b/src/Discord.Net.Core/Logging/LogMessage.cs @@ -19,7 +19,7 @@ namespace Discord } public override string ToString() => ToString(null); - public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 9) + public string ToString(StringBuilder builder = null, bool fullException = true, bool prependTimestamp = true, DateTimeKind timestampKind = DateTimeKind.Local, int? padSource = 11) { string sourceName = Source; string message = Message; @@ -87,8 +87,11 @@ namespace Discord } if (exMessage != null) { - builder.Append(':'); - builder.AppendLine(); + if (!string.IsNullOrEmpty(Message)) + { + builder.Append(':'); + builder.AppendLine(); + } builder.Append(exMessage); } diff --git a/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj b/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj index da0c8b0fd..829951d19 100644 --- a/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj +++ b/src/Discord.Net.DebugTools/Discord.Net.DebugTools.csproj @@ -6,7 +6,7 @@ netstandard1.6 Discord.Net.DebugTools RogueException - A Discord.Net extension adding random helper classes for diagnosing issues. + A Discord.Net extension adding some helper classes for diagnosing issues. discord;discordapp https://github.com/RogueException/Discord.Net http://opensource.org/licenses/MIT diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs index 8948c87dc..80c4cb598 100644 --- a/src/Discord.Net.Rest/BaseDiscordClient.cs +++ b/src/Discord.Net.Rest/BaseDiscordClient.cs @@ -18,10 +18,9 @@ namespace Discord.Rest public event Func LoggedOut { add { _loggedOutEvent.Add(value); } remove { _loggedOutEvent.Remove(value); } } private readonly AsyncEvent> _loggedOutEvent = new AsyncEvent>(); - internal readonly Logger _restLogger, _queueLogger; - internal readonly SemaphoreSlim _connectionLock; - private bool _isFirstLogin; - private bool _isDisposed; + internal readonly Logger _restLogger; + private readonly SemaphoreSlim _stateLock; + private bool _isFirstLogin, _isDisposed; internal API.DiscordRestApiClient ApiClient { get; } internal LogManager LogManager { get; } @@ -35,17 +34,16 @@ namespace Discord.Rest LogManager = new LogManager(config.LogLevel); LogManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false); - _connectionLock = new SemaphoreSlim(1, 1); + _stateLock = new SemaphoreSlim(1, 1); _restLogger = LogManager.CreateLogger("Rest"); - _queueLogger = LogManager.CreateLogger("Queue"); _isFirstLogin = config.DisplayInitialLog; ApiClient.RequestQueue.RateLimitTriggered += async (id, info) => { if (info == null) - await _queueLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); + await _restLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); else - await _queueLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); + await _restLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); }; ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false); } @@ -53,12 +51,12 @@ namespace Discord.Rest /// public async Task LoginAsync(TokenType tokenType, string token, bool validateToken = true) { - await _connectionLock.WaitAsync().ConfigureAwait(false); + await _stateLock.WaitAsync().ConfigureAwait(false); try { await LoginInternalAsync(tokenType, token).ConfigureAwait(false); } - finally { _connectionLock.Release(); } + finally { _stateLock.Release(); } } private async Task LoginInternalAsync(TokenType tokenType, string token) { @@ -86,17 +84,17 @@ namespace Discord.Rest await _loggedInEvent.InvokeAsync().ConfigureAwait(false); } - protected virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); } + internal virtual Task OnLoginAsync(TokenType tokenType, string token) { return Task.Delay(0); } /// public async Task LogoutAsync() { - await _connectionLock.WaitAsync().ConfigureAwait(false); + await _stateLock.WaitAsync().ConfigureAwait(false); try { await LogoutInternalAsync().ConfigureAwait(false); } - finally { _connectionLock.Release(); } + finally { _stateLock.Release(); } } private async Task LogoutInternalAsync() { @@ -111,7 +109,7 @@ namespace Discord.Rest await _loggedOutEvent.InvokeAsync().ConfigureAwait(false); } - protected virtual Task OnLogoutAsync() { return Task.Delay(0); } + internal virtual Task OnLogoutAsync() { return Task.Delay(0); } internal virtual void Dispose(bool disposing) { @@ -161,8 +159,9 @@ namespace Discord.Rest Task IDiscordClient.GetVoiceRegionAsync(string id) => Task.FromResult(null); - Task IDiscordClient.ConnectAsync() { throw new NotSupportedException(); } - Task IDiscordClient.DisconnectAsync() { throw new NotSupportedException(); } - + Task IDiscordClient.StartAsync() + => Task.Delay(0); + Task IDiscordClient.StopAsync() + => Task.Delay(0); } } diff --git a/src/Discord.Net.Rest/ConnectionManager.cs b/src/Discord.Net.Rest/ConnectionManager.cs new file mode 100644 index 000000000..ab1f4790c --- /dev/null +++ b/src/Discord.Net.Rest/ConnectionManager.cs @@ -0,0 +1,199 @@ +using Discord.Logging; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Discord +{ + internal class ConnectionManager + { + public event Func Connected { add { _connectedEvent.Add(value); } remove { _connectedEvent.Remove(value); } } + private readonly AsyncEvent> _connectedEvent = new AsyncEvent>(); + public event Func Disconnected { add { _disconnectedEvent.Add(value); } remove { _disconnectedEvent.Remove(value); } } + private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>(); + + private readonly SemaphoreSlim _stateLock; + private readonly Logger _logger; + private readonly int _connectionTimeout; + private readonly Func _onConnecting; + private readonly Func _onDisconnecting; + + private TaskCompletionSource _connectionPromise, _readyPromise; + private CancellationTokenSource _combinedCancelToken, _reconnectCancelToken, _connectionCancelToken; + private Task _task; + + public ConnectionState State { get; private set; } + public CancellationToken CancelToken { get; private set; } + + public bool IsCompleted => _readyPromise.Task.IsCompleted; + + internal ConnectionManager(SemaphoreSlim stateLock, Logger logger, int connectionTimeout, + Func onConnecting, Func onDisconnecting, Action> clientDisconnectHandler) + { + _stateLock = stateLock; + _logger = logger; + _connectionTimeout = connectionTimeout; + _onConnecting = onConnecting; + _onDisconnecting = onDisconnecting; + + clientDisconnectHandler(ex => + { + if (ex != null) + Error(new Exception("WebSocket connection was closed", ex)); + else + Error(new Exception("WebSocket connection was closed")); + return Task.Delay(0); + }); + } + + public virtual async Task StartAsync() + { + await AcquireConnectionLock().ConfigureAwait(false); + var reconnectCancelToken = new CancellationTokenSource(); + _reconnectCancelToken = new CancellationTokenSource(); + _task = Task.Run(async () => + { + try + { + Random jitter = new Random(); + int nextReconnectDelay = 1000; + while (!reconnectCancelToken.IsCancellationRequested) + { + try + { + await ConnectAsync(reconnectCancelToken).ConfigureAwait(false); + nextReconnectDelay = 1000; //Reset delay + await _connectionPromise.Task.ConfigureAwait(false); + } + catch (OperationCanceledException ex) + { + Cancel(); //In case this exception didn't come from another Error call + await DisconnectAsync(ex, !reconnectCancelToken.IsCancellationRequested).ConfigureAwait(false); + } + catch (Exception ex) + { + Error(ex); //In case this exception didn't come from another Error call + if (!reconnectCancelToken.IsCancellationRequested) + { + await _logger.WarningAsync(ex).ConfigureAwait(false); + await DisconnectAsync(ex, true).ConfigureAwait(false); + } + else + { + await _logger.ErrorAsync(ex).ConfigureAwait(false); + await DisconnectAsync(ex, false).ConfigureAwait(false); + } + } + + if (!reconnectCancelToken.IsCancellationRequested) + { + //Wait before reconnecting + await Task.Delay(nextReconnectDelay, reconnectCancelToken.Token).ConfigureAwait(false); + nextReconnectDelay = (nextReconnectDelay * 2) + jitter.Next(-250, 250); + if (nextReconnectDelay > 60000) + nextReconnectDelay = 60000; + } + } + } + finally { _stateLock.Release(); } + }); + } + public virtual async Task StopAsync() + { + Cancel(); + var task = _task; + if (task != null) + await task.ConfigureAwait(false); + } + + private async Task ConnectAsync(CancellationTokenSource reconnectCancelToken) + { + _connectionCancelToken = new CancellationTokenSource(); + _combinedCancelToken = CancellationTokenSource.CreateLinkedTokenSource(_connectionCancelToken.Token, reconnectCancelToken.Token); + CancelToken = _combinedCancelToken.Token; + + _connectionPromise = new TaskCompletionSource(); + State = ConnectionState.Connecting; + await _logger.InfoAsync("Connecting").ConfigureAwait(false); + + try + { + var readyPromise = new TaskCompletionSource(); + _readyPromise = readyPromise; + + //Abort connection on timeout + var cancelToken = CancelToken; + var _ = Task.Run(async () => + { + try + { + await Task.Delay(_connectionTimeout, cancelToken).ConfigureAwait(false); + readyPromise.TrySetException(new TimeoutException()); + } + catch (OperationCanceledException) { } + }); + + await _onConnecting().ConfigureAwait(false); + + await _logger.InfoAsync("Connected").ConfigureAwait(false); + State = ConnectionState.Connected; + await _logger.DebugAsync("Raising Event").ConfigureAwait(false); + await _connectedEvent.InvokeAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + Error(ex); + throw; + } + } + private async Task DisconnectAsync(Exception ex, bool isReconnecting) + { + if (State == ConnectionState.Disconnected) return; + State = ConnectionState.Disconnecting; + await _logger.InfoAsync("Disconnecting").ConfigureAwait(false); + + await _onDisconnecting(ex).ConfigureAwait(false); + + await _logger.InfoAsync("Disconnected").ConfigureAwait(false); + State = ConnectionState.Disconnected; + await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false); + } + + public async Task CompleteAsync() + { + await _readyPromise.TrySetResultAsync(true).ConfigureAwait(false); + } + public async Task WaitAsync() + { + await _readyPromise.Task.ConfigureAwait(false); + } + + public void Cancel() + { + _readyPromise?.TrySetCanceled(); + _connectionPromise?.TrySetCanceled(); + _reconnectCancelToken?.Cancel(); + _connectionCancelToken?.Cancel(); + } + public void Error(Exception ex) + { + _readyPromise.TrySetException(ex); + _connectionPromise.TrySetException(ex); + _connectionCancelToken?.Cancel(); + } + public void CriticalError(Exception ex) + { + _reconnectCancelToken?.Cancel(); + Error(ex); + } + private async Task AcquireConnectionLock() + { + while (true) + { + await StopAsync().ConfigureAwait(false); + if (await _stateLock.WaitAsync(0).ConfigureAwait(false)) + break; + } + } + } +} \ No newline at end of file diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs index 0727576bf..0ff1a4821 100644 --- a/src/Discord.Net.Rest/DiscordRestClient.cs +++ b/src/Discord.Net.Rest/DiscordRestClient.cs @@ -16,14 +16,19 @@ namespace Discord.Rest private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config) => new API.DiscordRestApiClient(config.RestClientProvider, DiscordRestConfig.UserAgent); + internal override void Dispose(bool disposing) + { + if (disposing) + ApiClient.Dispose(); + } - protected override async Task OnLoginAsync(TokenType tokenType, string token) + internal override async Task OnLoginAsync(TokenType tokenType, string token) { var user = await ApiClient.GetMyUserAsync(new RequestOptions { RetryMode = RetryMode.AlwaysRetry }).ConfigureAwait(false); ApiClient.CurrentUserId = user.Id; base.CurrentUser = RestSelfUser.Create(this, user); } - protected override Task OnLogoutAsync() + internal override Task OnLogoutAsync() { _applicationInfo = null; return Task.Delay(0); diff --git a/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs b/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs index 2a9ae21bf..d3c50a5ec 100644 --- a/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs +++ b/src/Discord.Net.Rpc/DiscordRpcClient.Events.cs @@ -12,12 +12,12 @@ namespace Discord.Rpc remove { _connectedEvent.Remove(value); } } private readonly AsyncEvent> _connectedEvent = new AsyncEvent>(); - public event Func Disconnected + public event Func Disconnected { add { _disconnectedEvent.Add(value); } remove { _disconnectedEvent.Remove(value); } } - private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>(); + private readonly AsyncEvent> _disconnectedEvent = new AsyncEvent>(); public event Func Ready { add { _readyEvent.Add(value); } diff --git a/src/Discord.Net.Rpc/DiscordRpcClient.cs b/src/Discord.Net.Rpc/DiscordRpcClient.cs index 845ba97c6..5235c98d4 100644 --- a/src/Discord.Net.Rpc/DiscordRpcClient.cs +++ b/src/Discord.Net.Rpc/DiscordRpcClient.cs @@ -8,28 +8,22 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; -using System.Threading; using System.Threading.Tasks; +using System.Threading; namespace Discord.Rpc { public partial class DiscordRpcClient : BaseDiscordClient, IDiscordClient { - private readonly Logger _rpcLogger; private readonly JsonSerializer _serializer; - - private TaskCompletionSource _connectTask; - private CancellationTokenSource _cancelToken, _reconnectCancelToken; - private Task _reconnectTask; - private bool _canReconnect; + private readonly ConnectionManager _connection; + private readonly Logger _rpcLogger; + private readonly SemaphoreSlim _stateLock, _authorizeLock; public ConnectionState ConnectionState { get; private set; } public IReadOnlyCollection Scopes { get; private set; } public DateTimeOffset TokenExpiresAt { get; private set; } - //From DiscordRpcConfig - internal int ConnectionTimeout { get; private set; } - internal new API.DiscordRpcApiClient ApiClient => base.ApiClient as API.DiscordRpcApiClient; public new RestSelfUser CurrentUser { get { return base.CurrentUser as RestSelfUser; } private set { base.CurrentUser = value; } } public RestApplication ApplicationInfo { get; private set; } @@ -41,8 +35,11 @@ namespace Discord.Rpc public DiscordRpcClient(string clientId, string origin, DiscordRpcConfig config) : base(config, CreateApiClient(clientId, origin, config)) { - ConnectionTimeout = config.ConnectionTimeout; + _stateLock = new SemaphoreSlim(1, 1); + _authorizeLock = new SemaphoreSlim(1, 1); _rpcLogger = LogManager.CreateLogger("RPC"); + _connection = new ConnectionManager(_stateLock, _rpcLogger, config.ConnectionTimeout, + OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x); _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer.Error += (s, e) => @@ -53,177 +50,52 @@ namespace Discord.Rpc ApiClient.SentRpcMessage += async opCode => await _rpcLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.ReceivedRpcEvent += ProcessMessageAsync; - ApiClient.Disconnected += async ex => - { - if (ex != null) - { - await _rpcLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false); - await StartReconnectAsync(ex).ConfigureAwait(false); - } - else - await _rpcLogger.WarningAsync($"Connection Closed").ConfigureAwait(false); - }; } private static API.DiscordRpcApiClient CreateApiClient(string clientId, string origin, DiscordRpcConfig config) => new API.DiscordRpcApiClient(clientId, DiscordRestConfig.UserAgent, origin, config.RestClientProvider, config.WebSocketProvider); - - /// - public async Task ConnectAsync() + internal override void Dispose(bool disposing) { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try + if (disposing) { - await ConnectInternalAsync(false).ConfigureAwait(false); + StopAsync().GetAwaiter().GetResult(); + ApiClient.Dispose(); } - finally { _connectionLock.Release(); } } - private async Task ConnectInternalAsync(bool isReconnecting) - { - if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested) - _reconnectCancelToken.Cancel(); - var state = ConnectionState; - if (state == ConnectionState.Connecting || state == ConnectionState.Connected) - await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); + public Task StartAsync() => _connection.StartAsync(); + public Task StopAsync() => _connection.StopAsync(); - ConnectionState = ConnectionState.Connecting; - await _rpcLogger.InfoAsync("Connecting").ConfigureAwait(false); - try - { - var connectTask = new TaskCompletionSource(); - _connectTask = connectTask; - _cancelToken = new CancellationTokenSource(); - - //Abort connection on timeout - var _ = Task.Run(async () => - { - await Task.Delay(ConnectionTimeout).ConfigureAwait(false); - connectTask.TrySetException(new TimeoutException()); - }); - - await ApiClient.ConnectAsync().ConfigureAwait(false); - await _connectedEvent.InvokeAsync().ConfigureAwait(false); - - await _connectTask.Task.ConfigureAwait(false); - if (!isReconnecting) - _canReconnect = true; - ConnectionState = ConnectionState.Connected; - await _rpcLogger.InfoAsync("Connected").ConfigureAwait(false); - } - catch (Exception) - { - await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); - throw; - } - } - /// - public async Task DisconnectAsync() + private async Task OnConnectingAsync() { - if (_connectTask?.TrySetCanceled() ?? false) return; - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await DisconnectInternalAsync(null, false).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } + await _rpcLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); + await ApiClient.ConnectAsync().ConfigureAwait(false); + + await _connection.WaitAsync().ConfigureAwait(false); } - private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting) + private async Task OnDisconnectingAsync(Exception ex) { - if (!isReconnecting) - { - _canReconnect = false; - - if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested) - _reconnectCancelToken.Cancel(); - } - - if (ConnectionState == ConnectionState.Disconnected) return; - ConnectionState = ConnectionState.Disconnecting; - await _rpcLogger.InfoAsync("Disconnecting").ConfigureAwait(false); - - await _rpcLogger.DebugAsync("Disconnecting - CancelToken").ConfigureAwait(false); - //Signal tasks to complete - try { _cancelToken.Cancel(); } catch { } - - await _rpcLogger.DebugAsync("Disconnecting - ApiClient").ConfigureAwait(false); - //Disconnect from server + await _rpcLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); await ApiClient.DisconnectAsync().ConfigureAwait(false); - - ConnectionState = ConnectionState.Disconnected; - await _rpcLogger.InfoAsync("Disconnected").ConfigureAwait(false); - - await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false); } - private async Task StartReconnectAsync(Exception ex) + public async Task AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null) { - await _connectionLock.WaitAsync().ConfigureAwait(false); + await _authorizeLock.WaitAsync().ConfigureAwait(false); try { - if (!_canReconnect || _reconnectTask != null) return; - _reconnectCancelToken = new CancellationTokenSource(); - _reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token); - } - finally { _connectionLock.Release(); } - } - private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken) - { - if (ex == null) - { - if (_connectTask?.TrySetCanceled() ?? false) return; + await _connection.StartAsync().ConfigureAwait(false); + await _connection.WaitAsync().ConfigureAwait(false); + var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false); + await _connection.StopAsync().ConfigureAwait(false); + return result.Code; } - else + finally { - if (_connectTask?.TrySetException(ex) ?? false) return; - } - - try - { - Random jitter = new Random(); - int nextReconnectDelay = 1000; - while (true) - { - await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false); - nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250); - if (nextReconnectDelay > 60000) - nextReconnectDelay = 60000; - - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - if (cancelToken.IsCancellationRequested) return; - await ConnectInternalAsync(true).ConfigureAwait(false); - _reconnectTask = null; - return; - } - catch (Exception ex2) - { - await _rpcLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - } - catch (OperationCanceledException) - { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await _rpcLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false); - _reconnectTask = null; - } - finally { _connectionLock.Release(); } + _authorizeLock.Release(); } } - public async Task AuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null) - { - await ConnectAsync().ConfigureAwait(false); - var result = await ApiClient.SendAuthorizeAsync(scopes, rpcToken, options).ConfigureAwait(false); - await DisconnectAsync().ConfigureAwait(false); - return result.Code; - } - public async Task SubscribeGlobal(RpcGlobalEvent evnt, RequestOptions options = null) { await ApiClient.SendGlobalSubscribeAsync(GetEventName(evnt), options).ConfigureAwait(false); @@ -439,8 +311,8 @@ namespace Discord.Rpc ApplicationInfo = RestApplication.Create(this, response.Application); Scopes = response.Scopes; TokenExpiresAt = response.Expires; - - var __ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete + + var __ = _connection.CompleteAsync(); await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false); } catch (Exception ex) @@ -452,7 +324,7 @@ namespace Discord.Rpc } else { - var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete + var _ = _connection.CompleteAsync(); await _rpcLogger.InfoAsync("Ready").ConfigureAwait(false); } } @@ -592,6 +464,13 @@ namespace Discord.Rpc } //IDiscordClient + ConnectionState IDiscordClient.ConnectionState => _connection.State; + Task IDiscordClient.GetApplicationInfoAsync() => Task.FromResult(ApplicationInfo); + + async Task IDiscordClient.StartAsync() + => await StartAsync().ConfigureAwait(false); + async Task IDiscordClient.StopAsync() + => await StopAsync().ConfigureAwait(false); } } diff --git a/src/Discord.Net.WebSocket/Audio/AudioClient.cs b/src/Discord.Net.WebSocket/Audio/AudioClient.cs index 5bedf1786..5404227f2 100644 --- a/src/Discord.Net.WebSocket/Audio/AudioClient.cs +++ b/src/Discord.Net.WebSocket/Audio/AudioClient.cs @@ -5,6 +5,7 @@ using Discord.WebSocket; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; +using System.Collections.Concurrent; using System.Linq; using System.Text; using System.Threading; @@ -12,6 +13,7 @@ using System.Threading.Tasks; namespace Discord.Audio { + //TODO: Add audio reconnecting internal class AudioClient : IAudioClient, IDisposable { public event Func Connected @@ -34,34 +36,37 @@ namespace Discord.Audio private readonly AsyncEvent> _latencyUpdatedEvent = new AsyncEvent>(); private readonly Logger _audioLogger; - internal readonly SemaphoreSlim _connectionLock; private readonly JsonSerializer _serializer; + private readonly ConnectionManager _connection; + private readonly SemaphoreSlim _stateLock; + private readonly ConcurrentQueue _heartbeatTimes; - private TaskCompletionSource _connectTask; - private CancellationTokenSource _cancelTokenSource; private Task _heartbeatTask; - private long _heartbeatTime; - private string _url; + private long _lastMessageTime; + private string _url, _sessionId, _token; + private ulong _userId; private uint _ssrc; private byte[] _secretKey; - private bool _isDisposed; public SocketGuild Guild { get; } public DiscordVoiceAPIClient ApiClient { get; private set; } - public ConnectionState ConnectionState { get; private set; } public int Latency { get; private set; } private DiscordSocketClient Discord => Guild.Discord; + public ConnectionState ConnectionState => _connection.State; /// Creates a new REST/WebSocket discord client. internal AudioClient(SocketGuild guild, int id) { Guild = guild; - _audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}"); - - _connectionLock = new SemaphoreSlim(1, 1); + _stateLock = new SemaphoreSlim(1, 1); + _connection = new ConnectionManager(_stateLock, _audioLogger, 30000, + OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x); + _heartbeatTimes = new ConcurrentQueue(); + _audioLogger = Discord.LogManager.CreateLogger($"Audio #{id}"); + _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer.Error += (s, e) => { @@ -76,83 +81,28 @@ namespace Discord.Audio //ApiClient.SentData += async bytes => await _audioLogger.DebugAsync($"Sent {bytes} Bytes").ConfigureAwait(false); ApiClient.ReceivedEvent += ProcessMessageAsync; ApiClient.ReceivedPacket += ProcessPacketAsync; - ApiClient.Disconnected += async ex => - { - if (ex != null) - await _audioLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false); - else - await _audioLogger.WarningAsync($"Connection Closed").ConfigureAwait(false); - }; LatencyUpdated += async (old, val) => await _audioLogger.VerboseAsync($"Latency = {val} ms").ConfigureAwait(false); } - /// - internal async Task ConnectAsync(string url, ulong userId, string sessionId, string token) + internal async Task StartAsync(string url, ulong userId, string sessionId, string token) { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await ConnectInternalAsync(url, userId, sessionId, token).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } + _url = url; + _userId = userId; + _sessionId = sessionId; + _token = token; + await _connection.StartAsync().ConfigureAwait(false); } - private async Task ConnectInternalAsync(string url, ulong userId, string sessionId, string token) - { - var state = ConnectionState; - if (state == ConnectionState.Connecting || state == ConnectionState.Connected) - await DisconnectInternalAsync(null).ConfigureAwait(false); - - ConnectionState = ConnectionState.Connecting; - await _audioLogger.InfoAsync("Connecting").ConfigureAwait(false); - try - { - _url = url; - _connectTask = new TaskCompletionSource(); - _cancelTokenSource = new CancellationTokenSource(); - - await ApiClient.ConnectAsync("wss://" + url).ConfigureAwait(false); - await ApiClient.SendIdentityAsync(userId, sessionId, token).ConfigureAwait(false); - await _connectTask.Task.ConfigureAwait(false); + public async Task StopAsync() + => await _connection.StopAsync().ConfigureAwait(false); - await _connectedEvent.InvokeAsync().ConfigureAwait(false); - ConnectionState = ConnectionState.Connected; - await _audioLogger.InfoAsync("Connected").ConfigureAwait(false); - } - catch (Exception) - { - await DisconnectInternalAsync(null).ConfigureAwait(false); - throw; - } - } - /// - public async Task DisconnectAsync() + private async Task OnConnectingAsync() { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await DisconnectInternalAsync(null).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } + await ApiClient.ConnectAsync("wss://" + _url).ConfigureAwait(false); + await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false); } - private async Task DisconnectAsync(Exception ex) + private async Task OnDisconnectingAsync(Exception ex) { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await DisconnectInternalAsync(ex).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - private async Task DisconnectInternalAsync(Exception ex) - { - if (ConnectionState == ConnectionState.Disconnected) return; - ConnectionState = ConnectionState.Disconnecting; - await _audioLogger.InfoAsync("Disconnecting").ConfigureAwait(false); - - //Signal tasks to complete - try { _cancelTokenSource.Cancel(); } catch { } - //Disconnect from server await ApiClient.DisconnectAsync().ConfigureAwait(false); @@ -162,17 +112,17 @@ namespace Discord.Audio await heartbeatTask.ConfigureAwait(false); _heartbeatTask = null; - ConnectionState = ConnectionState.Disconnected; - await _audioLogger.InfoAsync("Disconnected").ConfigureAwait(false); - await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false); - await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false); + + long time; + while (_heartbeatTimes.TryDequeue(out time)) { } + _lastMessageTime = 0; } public AudioOutStream CreateOpusStream(int samplesPerFrame, int bufferMillis) { CheckSamplesPerFrame(samplesPerFrame); - var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token); + var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken); return new RTPWriteStream(target, _secretKey, samplesPerFrame, _ssrc); } public AudioOutStream CreateDirectOpusStream(int samplesPerFrame) @@ -184,7 +134,7 @@ namespace Discord.Audio public AudioOutStream CreatePCMStream(int samplesPerFrame, int channels, int? bitrate, int bufferMillis) { CheckSamplesPerFrame(samplesPerFrame); - var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _cancelTokenSource.Token); + var target = new BufferedAudioTarget(ApiClient, samplesPerFrame, bufferMillis, _connection.CancelToken); return new OpusEncodeStream(target, _secretKey, channels, samplesPerFrame, _ssrc, bitrate); } public AudioOutStream CreateDirectPCMStream(int samplesPerFrame, int channels, int? bitrate) @@ -202,6 +152,8 @@ namespace Discord.Audio private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload) { + _lastMessageTime = Environment.TickCount; + try { switch (opCode) @@ -216,8 +168,7 @@ namespace Discord.Audio if (!data.Modes.Contains(DiscordVoiceAPIClient.Mode)) throw new InvalidOperationException($"Discord does not support {DiscordVoiceAPIClient.Mode}"); - _heartbeatTime = 0; - _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelTokenSource.Token); + _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken); ApiClient.SetUdpEndpoint(_url, data.Port); await ApiClient.SendDiscoveryAsync(_ssrc).ConfigureAwait(false); @@ -234,19 +185,17 @@ namespace Discord.Audio _secretKey = data.SecretKey; await ApiClient.SendSetSpeaking(true).ConfigureAwait(false); - var _ = _connectTask.TrySetResultAsync(true); + var _ = _connection.CompleteAsync(); } break; case VoiceOpCode.HeartbeatAck: { await _audioLogger.DebugAsync("Received HeartbeatAck").ConfigureAwait(false); - var heartbeatTime = _heartbeatTime; - if (heartbeatTime != 0) + long time; + if (_heartbeatTimes.TryDequeue(out time)) { - int latency = (int)(Environment.TickCount - _heartbeatTime); - _heartbeatTime = 0; - + int latency = (int)(Environment.TickCount - time); int before = Latency; Latency = latency; @@ -267,7 +216,7 @@ namespace Discord.Audio } private async Task ProcessPacketAsync(byte[] packet) { - if (!_connectTask.Task.IsCompleted) + if (!_connection.IsCompleted) { if (packet.Length == 70) { @@ -291,33 +240,50 @@ namespace Discord.Audio //Clean this up when Discord's session patch is live try { + await _audioLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false); while (!cancelToken.IsCancellationRequested) { + var now = Environment.TickCount; + + //Did server respond to our last heartbeat, or are we still receiving messages (long load?) + if (_heartbeatTimes.Count != 0 && (now - _lastMessageTime) > intervalMillis && + ConnectionState == ConnectionState.Connected) + { + _connection.Error(new Exception("Server missed last heartbeat")); + return; + } + _heartbeatTimes.Enqueue(now); + await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); - if (_heartbeatTime != 0) //Server never responded to our last heartbeat + try { - if (ConnectionState == ConnectionState.Connected) - { - await _audioLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false); - await DisconnectInternalAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false); - return; - } + await ApiClient.SendHeartbeatAsync().ConfigureAwait(false); } - else - _heartbeatTime = Environment.TickCount; - await ApiClient.SendHeartbeatAsync().ConfigureAwait(false); + catch (Exception ex) + { + await _audioLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false); + } + + await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); } + await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); + } + catch (OperationCanceledException) + { + await _audioLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); + } + catch (Exception ex) + { + await _audioLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false); } - catch (OperationCanceledException) { } } internal void Dispose(bool disposing) { - if (disposing && !_isDisposed) + if (disposing) { - _isDisposed = true; - DisconnectInternalAsync(null).GetAwaiter().GetResult(); + StopAsync().GetAwaiter().GetResult(); ApiClient.Dispose(); } } diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index e897a0b40..cadbda6d1 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -72,7 +72,7 @@ namespace Discord.WebSocket private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent); - protected override async Task OnLoginAsync(TokenType tokenType, string token) + internal override async Task OnLoginAsync(TokenType tokenType, string token) { if (_automaticShards) { @@ -95,7 +95,7 @@ namespace Discord.WebSocket for (int i = 0; i < _shards.Length; i++) await _shards[i].LoginAsync(tokenType, token, false); } - protected override async Task OnLogoutAsync() + internal override async Task OnLogoutAsync() { //Assume threadsafe: already in a connection lock for (int i = 0; i < _shards.Length; i++) @@ -112,42 +112,14 @@ namespace Discord.WebSocket } /// - public async Task ConnectAsync() + public async Task StartAsync() { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await ConnectInternalAsync().ConfigureAwait(false); - } - catch - { - await DisconnectInternalAsync().ConfigureAwait(false); - throw; - } - finally { _connectionLock.Release(); } - } - private async Task ConnectInternalAsync() - { - await Task.WhenAll( - _shards.Select(x => x.ConnectAsync()) - ).ConfigureAwait(false); - - CurrentUser = _shards[0].CurrentUser; + await Task.WhenAll(_shards.Select(x => x.StartAsync())).ConfigureAwait(false); } /// - public async Task DisconnectAsync() + public async Task StopAsync() { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await DisconnectInternalAsync().ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - private async Task DisconnectInternalAsync() - { - for (int i = 0; i < _shards.Length; i++) - await _shards[i].DisconnectAsync(); + await Task.WhenAll(_shards.Select(x => x.StopAsync())).ConfigureAwait(false); } public DiscordSocketClient GetShard(int id) @@ -334,9 +306,6 @@ namespace Discord.WebSocket } //IDiscordClient - Task IDiscordClient.ConnectAsync() - => ConnectAsync(); - async Task IDiscordClient.GetApplicationInfoAsync() => await GetApplicationInfoAsync().ConfigureAwait(false); diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index cbefd795c..7d680eaf2 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -28,6 +28,7 @@ namespace Discord.API private CancellationTokenSource _connectCancelToken; private string _gatewayUrl; + private bool _isExplicitUrl; internal IWebSocketClient WebSocketClient { get; } @@ -38,6 +39,8 @@ namespace Discord.API : base(restClientProvider, userAgent, defaultRetryMode, serializer) { _gatewayUrl = url; + if (url != null) + _isExplicitUrl = true; WebSocketClient = webSocketProvider(); //WebSocketClient.SetHeader("user-agent", DiscordConfig.UserAgent); (Causes issues in .NET Framework 4.6+) WebSocketClient.BinaryMessage += async (data, index, count) => @@ -52,7 +55,8 @@ namespace Discord.API using (var jsonReader = new JsonTextReader(reader)) { var msg = _serializer.Deserialize(jsonReader); - await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); + if (msg != null) + await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); } } }; @@ -62,7 +66,8 @@ namespace Discord.API using (var jsonReader = new JsonTextReader(reader)) { var msg = _serializer.Deserialize(jsonReader); - await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); + if (msg != null) + await _receivedGatewayEvent.InvokeAsync((GatewayOpCode)msg.Operation, msg.Sequence, msg.Type, msg.Payload).ConfigureAwait(false); } }; WebSocketClient.Closed += async ex => @@ -107,7 +112,7 @@ namespace Discord.API if (WebSocketClient != null) WebSocketClient.SetCancelToken(_connectCancelToken.Token); - if (_gatewayUrl == null) + if (!_isExplicitUrl) { var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); _gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}"; @@ -118,7 +123,8 @@ namespace Discord.API } catch { - _gatewayUrl = null; //Uncache in case the gateway url changed + if (!_isExplicitUrl) + _gatewayUrl = null; //Uncache in case the gateway url changed await DisconnectInternalAsync().ConfigureAwait(false); throw; } diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index c0608a868..092225376 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -17,29 +17,27 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; using GameModel = Discord.API.Game; -using Discord.Net; namespace Discord.WebSocket { public partial class DiscordSocketClient : BaseDiscordClient, IDiscordClient { private readonly ConcurrentQueue _largeGuilds; - private readonly Logger _gatewayLogger; private readonly JsonSerializer _serializer; private readonly SemaphoreSlim _connectionGroupLock; private readonly DiscordSocketClient _parentClient; private readonly ConcurrentQueue _heartbeatTimes; + private readonly ConnectionManager _connection; + private readonly Logger _gatewayLogger; + private readonly SemaphoreSlim _stateLock; private string _sessionId; private int _lastSeq; private ImmutableDictionary _voiceRegions; - private TaskCompletionSource _connectTask; - private CancellationTokenSource _cancelToken, _reconnectCancelToken; - private Task _heartbeatTask, _guildDownloadTask, _reconnectTask; + private Task _heartbeatTask, _guildDownloadTask; private int _unavailableGuilds; private long _lastGuildAvailableTime, _lastMessageTime; private int _nextAudioId; - private bool _canReconnect; private DateTimeOffset? _statusSince; private RestApplication _applicationInfo; private ConcurrentHashSet _downloadUsersFor; @@ -59,7 +57,6 @@ namespace Discord.WebSocket internal int LargeThreshold { get; private set; } internal AudioMode AudioMode { get; private set; } internal ClientState State { get; private set; } - internal int ConnectionTimeout { get; private set; } internal UdpSocketProvider UdpSocketProvider { get; private set; } internal WebSocketProvider WebSocketProvider { get; private set; } internal bool AlwaysDownloadUsers { get; private set; } @@ -90,35 +87,28 @@ namespace Discord.WebSocket UdpSocketProvider = config.UdpSocketProvider; WebSocketProvider = config.WebSocketProvider; AlwaysDownloadUsers = config.AlwaysDownloadUsers; - ConnectionTimeout = config.ConnectionTimeout; State = new ClientState(0, 0); _downloadUsersFor = new ConcurrentHashSet(); _heartbeatTimes = new ConcurrentQueue(); + + _stateLock = new SemaphoreSlim(1, 1); + _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : $"Shard #{ShardId}"); + _connection = new ConnectionManager(_stateLock, _gatewayLogger, config.ConnectionTimeout, + OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x); _nextAudioId = 1; - _gatewayLogger = LogManager.CreateLogger(ShardId == 0 && TotalShards == 1 ? "Gateway" : "Shard #" + ShardId); _connectionGroupLock = groupLock; _parentClient = parentClient; _serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() }; _serializer.Error += (s, e) => { - _gatewayLogger.WarningAsync(e.ErrorContext.Error).GetAwaiter().GetResult(); + _gatewayLogger.WarningAsync("Serializer Error", e.ErrorContext.Error).GetAwaiter().GetResult(); e.ErrorContext.Handled = true; }; ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; - ApiClient.Disconnected += async ex => - { - if (ex != null) - { - await _gatewayLogger.WarningAsync($"Connection Closed", ex).ConfigureAwait(false); - await StartReconnectAsync(ex).ConfigureAwait(false); - } - else - await _gatewayLogger.WarningAsync($"Connection Closed").ConfigureAwait(false); - }; LeftGuild += async g => await _gatewayLogger.InfoAsync($"Left {g.Name}").ConfigureAwait(false); JoinedGuild += async g => await _gatewayLogger.InfoAsync($"Joined {g.Name}").ConfigureAwait(false); @@ -143,8 +133,16 @@ namespace Discord.WebSocket } private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) => new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost); + internal override void Dispose(bool disposing) + { + if (disposing) + { + StopAsync().GetAwaiter().GetResult(); + ApiClient.Dispose(); + } + } - protected override async Task OnLoginAsync(TokenType tokenType, string token) + internal override async Task OnLoginAsync(TokenType tokenType, string token) { if (_parentClient == null) { @@ -154,92 +152,49 @@ namespace Discord.WebSocket else _voiceRegions = _parentClient._voiceRegions; } - protected override async Task OnLogoutAsync() + internal override async Task OnLogoutAsync() { - if (ConnectionState != ConnectionState.Disconnected) - await DisconnectInternalAsync(null, false).ConfigureAwait(false); - + await StopAsync().ConfigureAwait(false); _applicationInfo = null; _voiceRegions = ImmutableDictionary.Create(); _downloadUsersFor.Clear(); } + + public async Task StartAsync() + => await _connection.StartAsync().ConfigureAwait(false); + public async Task StopAsync() + => await _connection.StopAsync().ConfigureAwait(false); - /// - public async Task ConnectAsync() + private async Task OnConnectingAsync() { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await ConnectInternalAsync(false).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - private async Task ConnectInternalAsync(bool isReconnecting) - { - if (LoginState != LoginState.LoggedIn) - throw new InvalidOperationException("Client is not logged in."); - - if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested) - _reconnectCancelToken.Cancel(); - - var state = ConnectionState; - if (state == ConnectionState.Connecting || state == ConnectionState.Connected) - await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); - if (_connectionGroupLock != null) - await _connectionGroupLock.WaitAsync().ConfigureAwait(false); + await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false); try { - _canReconnect = true; - ConnectionState = ConnectionState.Connecting; - await _gatewayLogger.InfoAsync("Connecting").ConfigureAwait(false); - - try - { - var connectTask = new TaskCompletionSource(); - _connectTask = connectTask; - _cancelToken = new CancellationTokenSource(); - - //Abort connection on timeout - var _ = Task.Run(async () => - { - await Task.Delay(ConnectionTimeout).ConfigureAwait(false); - connectTask.TrySetException(new TimeoutException()); - }); - - await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); - await ApiClient.ConnectAsync().ConfigureAwait(false); - await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); - await _connectedEvent.InvokeAsync().ConfigureAwait(false); - - if (_sessionId != null) - { - await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); - await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); - } - else - { - await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); - await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); - } - - await _connectTask.Task.ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false); + await ApiClient.ConnectAsync().ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); + await _connectedEvent.InvokeAsync().ConfigureAwait(false); - await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); - await SendStatusAsync().ConfigureAwait(false); - - await _gatewayLogger.DebugAsync("Raising Event").ConfigureAwait(false); - ConnectionState = ConnectionState.Connected; - await _gatewayLogger.InfoAsync("Connected").ConfigureAwait(false); - - await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x)) - .Where(x => x != null).ToImmutableArray()).ConfigureAwait(false); + if (_sessionId != null) + { + await _gatewayLogger.DebugAsync("Resuming").ConfigureAwait(false); + await ApiClient.SendResumeAsync(_sessionId, _lastSeq).ConfigureAwait(false); } - catch (Exception) + else { - await DisconnectInternalAsync(null, isReconnecting).ConfigureAwait(false); - throw; + await _gatewayLogger.DebugAsync("Identifying").ConfigureAwait(false); + await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards).ConfigureAwait(false); } + + //Wait for READY + await _connection.WaitAsync().ConfigureAwait(false); + + await _gatewayLogger.DebugAsync("Sending Status").ConfigureAwait(false); + await SendStatusAsync().ConfigureAwait(false); + + await ProcessUserDownloadsAsync(_downloadUsersFor.Select(x => GetGuild(x)) + .Where(x => x != null).ToImmutableArray()).ConfigureAwait(false); } finally { @@ -250,41 +205,11 @@ namespace Discord.WebSocket } } } - /// - public async Task DisconnectAsync() - { - if (_connectTask?.TrySetCanceled() ?? false) return; - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await DisconnectInternalAsync(null, false).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - private async Task DisconnectInternalAsync(Exception ex, bool isReconnecting) + private async Task OnDisconnectingAsync(Exception ex) { - if (!isReconnecting) - { - _canReconnect = false; - _sessionId = null; - _lastSeq = 0; - - if (_reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested) - _reconnectCancelToken.Cancel(); - } - ulong guildId; - if (ConnectionState == ConnectionState.Disconnected) return; - ConnectionState = ConnectionState.Disconnecting; - await _gatewayLogger.InfoAsync("Disconnecting").ConfigureAwait(false); - - await _gatewayLogger.DebugAsync("Cancelling current tasks").ConfigureAwait(false); - //Signal tasks to complete - try { _cancelToken.Cancel(); } catch { } - await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); - //Disconnect from server await ApiClient.DisconnectAsync().ConfigureAwait(false); //Wait for tasks to complete @@ -294,8 +219,8 @@ namespace Discord.WebSocket await heartbeatTask.ConfigureAwait(false); _heartbeatTask = null; - long times; - while (_heartbeatTimes.TryDequeue(out times)) { } + long time; + while (_heartbeatTimes.TryDequeue(out time)) { } _lastMessageTime = 0; await _gatewayLogger.DebugAsync("Waiting for guild downloader").ConfigureAwait(false); @@ -315,70 +240,6 @@ namespace Discord.WebSocket if (guild._available) await _guildUnavailableEvent.InvokeAsync(guild).ConfigureAwait(false); } - - ConnectionState = ConnectionState.Disconnected; - await _gatewayLogger.InfoAsync("Disconnected").ConfigureAwait(false); - - await _disconnectedEvent.InvokeAsync(ex).ConfigureAwait(false); - } - - private async Task StartReconnectAsync(Exception ex) - { - if ((ex as WebSocketClosedException)?.CloseCode == 4004) //Bad Token - { - _canReconnect = false; - _connectTask?.TrySetException(ex); - await LogoutAsync().ConfigureAwait(false); - return; - } - - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - if (!_canReconnect || _reconnectTask != null) return; - _reconnectCancelToken = new CancellationTokenSource(); - _reconnectTask = ReconnectInternalAsync(ex, _reconnectCancelToken.Token); - } - finally { _connectionLock.Release(); } - } - private async Task ReconnectInternalAsync(Exception ex, CancellationToken cancelToken) - { - try - { - Random jitter = new Random(); - int nextReconnectDelay = 1000; - while (true) - { - await Task.Delay(nextReconnectDelay, cancelToken).ConfigureAwait(false); - nextReconnectDelay = nextReconnectDelay * 2 + jitter.Next(-250, 250); - if (nextReconnectDelay > 60000) - nextReconnectDelay = 60000; - - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - if (cancelToken.IsCancellationRequested) return; - await ConnectInternalAsync(true).ConfigureAwait(false); - _reconnectTask = null; - return; - } - catch (Exception ex2) - { - await _gatewayLogger.WarningAsync("Reconnect failed", ex2).ConfigureAwait(false); - } - finally { _connectionLock.Release(); } - } - } - catch (OperationCanceledException) - { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - await _gatewayLogger.DebugAsync("Reconnect cancelled").ConfigureAwait(false); - _reconnectTask = null; - } - finally { _connectionLock.Release(); } - } } /// @@ -555,7 +416,7 @@ namespace Discord.WebSocket await _gatewayLogger.DebugAsync("Received Hello").ConfigureAwait(false); var data = (payload as JToken).ToObject(_serializer); - _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _cancelToken.Token, _gatewayLogger); + _heartbeatTask = RunHeartbeatAsync(data.HeartbeatInterval, _connection.CancelToken); } break; case GatewayOpCode.Heartbeat: @@ -593,9 +454,7 @@ namespace Discord.WebSocket case GatewayOpCode.Reconnect: { await _gatewayLogger.DebugAsync("Received Reconnect").ConfigureAwait(false); - await _gatewayLogger.WarningAsync("Server requested a reconnect").ConfigureAwait(false); - - await StartReconnectAsync(new Exception("Server requested a reconnect")).ConfigureAwait(false); + _connection.Error(new Exception("Server requested a reconnect")); } break; case GatewayOpCode.Dispatch: @@ -633,8 +492,7 @@ namespace Discord.WebSocket } catch (Exception ex) { - _canReconnect = false; - _connectTask.TrySetException(new Exception("Processing READY failed", ex)); + _connection.CriticalError(new Exception("Processing READY failed", ex)); return; } @@ -642,11 +500,11 @@ namespace Discord.WebSocket await SyncGuildsAsync().ConfigureAwait(false); _lastGuildAvailableTime = Environment.TickCount; - _guildDownloadTask = WaitForGuildsAsync(_cancelToken.Token, _gatewayLogger); + _guildDownloadTask = WaitForGuildsAsync(_connection.CancelToken, _gatewayLogger); await _readyEvent.InvokeAsync().ConfigureAwait(false); - var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete + var _ = _connection.CompleteAsync(); await _gatewayLogger.InfoAsync("Ready").ConfigureAwait(false); } break; @@ -654,7 +512,7 @@ namespace Discord.WebSocket { await _gatewayLogger.DebugAsync("Received Dispatch (RESUMED)").ConfigureAwait(false); - var _ = _connectTask.TrySetResultAsync(true); //Signal the .Connect() call to complete + var _ = _connection.CompleteAsync(); //Notify the client that these guilds are available again foreach (var guild in State.Guilds) @@ -1356,7 +1214,6 @@ namespace Discord.WebSocket SocketUserMessage cachedMsg = channel.GetCachedMessage(data.MessageId) as SocketUserMessage; var user = await channel.GetUserAsync(data.UserId, CacheMode.CacheOnly); SocketReaction reaction = SocketReaction.Create(data, channel, cachedMsg, Optional.Create(user)); - if (cachedMsg != null) { cachedMsg.AddReaction(reaction); @@ -1691,11 +1548,11 @@ namespace Discord.WebSocket } } - private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken, Logger logger) + private async Task RunHeartbeatAsync(int intervalMillis, CancellationToken cancelToken) { try { - await logger.DebugAsync("Heartbeat Started").ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Heartbeat Started").ConfigureAwait(false); while (!cancelToken.IsCancellationRequested) { var now = Environment.TickCount; @@ -1705,8 +1562,7 @@ namespace Discord.WebSocket { if (ConnectionState == ConnectionState.Connected && (_guildDownloadTask?.IsCompleted ?? true)) { - await _gatewayLogger.WarningAsync("Server missed last heartbeat").ConfigureAwait(false); - await StartReconnectAsync(new Exception("Server missed last heartbeat")).ConfigureAwait(false); + _connection.Error(new Exception("Server missed last heartbeat")); return; } } @@ -1718,20 +1574,20 @@ namespace Discord.WebSocket } catch (Exception ex) { - await logger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false); + await _gatewayLogger.WarningAsync("Heartbeat Errored", ex).ConfigureAwait(false); } await Task.Delay(intervalMillis, cancelToken).ConfigureAwait(false); } - await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); } catch (OperationCanceledException) { - await logger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); + await _gatewayLogger.DebugAsync("Heartbeat Stopped").ConfigureAwait(false); } catch (Exception ex) { - await logger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false); + await _gatewayLogger.ErrorAsync("Heartbeat Errored", ex).ConfigureAwait(false); } } public async Task WaitForGuildsAsync() @@ -1805,8 +1661,7 @@ namespace Discord.WebSocket } //IDiscordClient - Task IDiscordClient.ConnectAsync() - => ConnectAsync(); + ConnectionState IDiscordClient.ConnectionState => _connection.State; async Task IDiscordClient.GetApplicationInfoAsync() => await GetApplicationInfoAsync().ConfigureAwait(false); @@ -1842,5 +1697,10 @@ namespace Discord.WebSocket => Task.FromResult>(VoiceRegions); Task IDiscordClient.GetVoiceRegionAsync(string id) => Task.FromResult(GetVoiceRegion(id)); + + async Task IDiscordClient.StartAsync() + => await StartAsync().ConfigureAwait(false); + async Task IDiscordClient.StopAsync() + => await StopAsync().ConfigureAwait(false); } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs index 78a637d0e..f42744c79 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs @@ -9,7 +9,7 @@ namespace Discord.WebSocket { public const string GatewayEncoding = "json"; - /// Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint. + /// Gets or sets the websocket host to connect to. If null, the client will use the /gateway endpoint. public string GatewayHost { get; set; } = null; /// Gets or sets the time, in milliseconds, to wait for a connection to complete before aborting. diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 22a4c2a71..4e16985a7 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -501,7 +501,7 @@ namespace Discord.WebSocket _audioConnectPromise?.TrySetCanceledAsync(); //Cancel any previous audio connection _audioConnectPromise = null; if (_audioClient != null) - await _audioClient.DisconnectAsync().ConfigureAwait(false); + await _audioClient.StopAsync().ConfigureAwait(false); _audioClient = null; } internal async Task FinishConnectAudio(int id, string url, string token) @@ -517,7 +517,6 @@ namespace Discord.WebSocket var promise = _audioConnectPromise; audioClient.Disconnected += async ex => { - //If the initial connection hasn't been made yet, reconnecting will lead to deadlocks if (!promise.Task.IsCompleted) { try { audioClient.Dispose(); } catch { } @@ -528,41 +527,15 @@ namespace Discord.WebSocket await promise.TrySetCanceledAsync(); return; } - - //TODO: Implement reconnect - /*await _audioLock.WaitAsync().ConfigureAwait(false); - try - { - if (AudioClient == audioClient) //Only reconnect if we're still assigned as this guild's audio client - { - if (ex != null) - { - //Reconnect if we still have channel info. - //TODO: Is this threadsafe? Could channel data be deleted before we access it? - var voiceState2 = GetVoiceState(Discord.CurrentUser.Id); - if (voiceState2.HasValue) - { - var voiceChannelId = voiceState2.Value.VoiceChannel?.Id; - if (voiceChannelId != null) - { - await Discord.ApiClient.SendVoiceStateUpdateAsync(Id, voiceChannelId, voiceState2.Value.IsSelfDeafened, voiceState2.Value.IsSelfMuted); - return; - } - } - } - try { audioClient.Dispose(); } catch { } - AudioClient = null; - } - } - finally - { - _audioLock.Release(); - }*/ }; _audioClient = audioClient; } - await _audioClient.ConnectAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false); - await _audioConnectPromise.TrySetResultAsync(_audioClient).ConfigureAwait(false); + _audioClient.Connected += () => + { + var _ = _audioConnectPromise.TrySetResultAsync(_audioClient); + return Task.Delay(0); + }; + await _audioClient.StartAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false); } catch (OperationCanceledException) {