diff --git a/src/Discord.Net.Core/API/DiscordRestApiClient.cs b/src/Discord.Net.Core/API/DiscordRestApiClient.cs index baa0698d8..e2fcbf9af 100644 --- a/src/Discord.Net.Core/API/DiscordRestApiClient.cs +++ b/src/Discord.Net.Core/API/DiscordRestApiClient.cs @@ -6,13 +6,16 @@ using Discord.Net.Queue; using Discord.Net.Rest; using Newtonsoft.Json; using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; +using System.Linq.Expressions; using System.Net; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -21,9 +24,11 @@ namespace Discord.API { public class DiscordRestApiClient : IDisposable { + private static readonly ConcurrentDictionary> _bucketIdGenerators = new ConcurrentDictionary>(); + public event Func SentRequest { add { _sentRequestEvent.Add(value); } remove { _sentRequestEvent.Remove(value); } } private readonly AsyncEvent> _sentRequestEvent = new AsyncEvent>(); - + protected readonly JsonSerializer _serializer; protected readonly SemaphoreSlim _stateLock; private readonly RestClientProvider _restClientProvider; @@ -112,6 +117,7 @@ namespace Discord.API _restClient.SetCancelToken(_loginCancelToken.Token); AuthTokenType = tokenType; + RequestQueue.TokenType = tokenType; _authToken = token; _restClient.SetHeader("authorization", GetPrefixedToken(AuthTokenType, _authToken)); @@ -159,42 +165,61 @@ namespace Discord.API internal virtual Task DisconnectInternalAsync() => Task.CompletedTask; //Core - public async Task SendAsync(string method, string endpoint, RequestOptions options = null) + public async Task SendAsync(string method, string endpoint, string bucketId, RequestOptions options) { options.HeaderOnly = true; - var request = new RestRequest(_restClient, method, endpoint, options); + var request = new RestRequest(_restClient, method, endpoint, bucketId, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendJsonAsync(string method, string endpoint, object payload, RequestOptions options = null) + public async Task SendJsonAsync(string method, string endpoint, string bucketId, object payload, RequestOptions options) { options.HeaderOnly = true; var json = payload != null ? SerializeJson(payload) : null; - var request = new JsonRestRequest(_restClient, method, endpoint, json, options); + var request = new JsonRestRequest(_restClient, method, endpoint, bucketId, json, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary multipartArgs, RequestOptions options = null) + public async Task SendMultipartAsync(string method, string endpoint, string bucketId, IReadOnlyDictionary multipartArgs, RequestOptions options) { options.HeaderOnly = true; - var request = new MultipartRestRequest(_restClient, method, endpoint, multipartArgs, options); + var request = new MultipartRestRequest(_restClient, method, endpoint, bucketId, multipartArgs, options); await SendInternalAsync(method, endpoint, request).ConfigureAwait(false); } - public async Task SendAsync(string method, string endpoint, RequestOptions options = null) where TResponse : class + public async Task SendAsync(string method, string endpoint, string bucketId, RequestOptions options) where TResponse : class { - var request = new RestRequest(_restClient, method, endpoint, options); + var request = new RestRequest(_restClient, method, endpoint, bucketId, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - public async Task SendJsonAsync(string method, string endpoint, object payload, RequestOptions options = null) where TResponse : class + public async Task SendJsonAsync(string method, string endpoint, string bucketId, object payload, RequestOptions options) where TResponse : class { var json = payload != null ? SerializeJson(payload) : null; - var request = new JsonRestRequest(_restClient, method, endpoint, json, options); + var request = new JsonRestRequest(_restClient, method, endpoint, bucketId, json, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - public async Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary multipartArgs, RequestOptions options = null) + public async Task SendMultipartAsync(string method, string endpoint, string bucketId, IReadOnlyDictionary multipartArgs, RequestOptions options) { - var request = new MultipartRestRequest(_restClient, method, endpoint, multipartArgs, options); + var request = new MultipartRestRequest(_restClient, method, endpoint, bucketId, multipartArgs, options); return DeserializeJson(await SendInternalAsync(method, endpoint, request).ConfigureAwait(false)); } - + + internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) + => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), options); + internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) + => SendJsonAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), payload, options); + internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) + => SendMultipartAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), multipartArgs, options); + internal Task SendAsync(string method, Expression> endpointExpr, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) where TResponse : class + => SendAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), options); + internal Task SendJsonAsync(string method, Expression> endpointExpr, object payload, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) where TResponse : class + => SendJsonAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), payload, options); + internal Task SendMultipartAsync(string method, Expression> endpointExpr, IReadOnlyDictionary multipartArgs, BucketIds ids, + RequestOptions options, [CallerMemberName] string funcName = null) + => SendMultipartAsync(method, GetEndpoint(endpointExpr), GetBucketId(ids, endpointExpr, funcName), multipartArgs, options); + private async Task SendInternalAsync(string method, string endpoint, RestRequest request) { if (!request.Options.IgnoreState) @@ -214,7 +239,7 @@ namespace Discord.API public async Task ValidateTokenAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - await SendAsync("GET", "auth/login", options: options).ConfigureAwait(false); + await SendAsync("GET", () => "auth/login", new BucketIds(), options: options).ConfigureAwait(false); } //Channels @@ -225,7 +250,8 @@ namespace Discord.API try { - return await SendAsync("GET", $"channels/{channelId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendAsync("GET", () => $"channels/{channelId}", ids, options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -237,7 +263,8 @@ namespace Discord.API try { - var model = await SendAsync("GET", $"channels/{channelId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + var model = await SendAsync("GET", () => $"channels/{channelId}", ids, options: options).ConfigureAwait(false); if (!model.GuildId.IsSpecified || model.GuildId.Value != guildId) return null; return model; @@ -249,7 +276,8 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/channels", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/channels", ids, options: options).ConfigureAwait(false); } public async Task CreateGuildChannelAsync(ulong guildId, CreateGuildChannelParams args, RequestOptions options = null) { @@ -259,14 +287,16 @@ namespace Discord.API Preconditions.NotNullOrWhitespace(args.Name, nameof(args.Name)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("POST", $"guilds/{guildId}/channels", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("POST", () => $"guilds/{guildId}/channels", args, ids, options: options).ConfigureAwait(false); } public async Task DeleteChannelAsync(ulong channelId, RequestOptions options = null) { Preconditions.NotEqual(channelId, 0, nameof(channelId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("DELETE", $"channels/{channelId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendAsync("DELETE", () => $"channels/{channelId}", ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildChannelAsync(ulong channelId, ModifyGuildChannelParams args, RequestOptions options = null) { @@ -276,7 +306,8 @@ namespace Discord.API Preconditions.NotNullOrEmpty(args.Name, nameof(args.Name)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"channels/{channelId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("PATCH", () => $"channels/{channelId}", args, ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildChannelAsync(ulong channelId, ModifyTextChannelParams args, RequestOptions options = null) { @@ -286,7 +317,8 @@ namespace Discord.API Preconditions.NotNullOrEmpty(args.Name, nameof(args.Name)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"channels/{channelId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("PATCH", () => $"channels/{channelId}", args, ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildChannelAsync(ulong channelId, ModifyVoiceChannelParams args, RequestOptions options = null) { @@ -298,7 +330,8 @@ namespace Discord.API Preconditions.NotNullOrEmpty(args.Name, nameof(args.Name)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"channels/{channelId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("PATCH", () => $"channels/{channelId}", args, ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildChannelsAsync(ulong guildId, IEnumerable args, RequestOptions options = null) { @@ -315,7 +348,8 @@ namespace Discord.API await ModifyGuildChannelAsync(channels[0].Id, new ModifyGuildChannelParams { Position = channels[0].Position }).ConfigureAwait(false); break; default: - await SendJsonAsync("PATCH", $"guilds/{guildId}/channels", channels, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendJsonAsync("PATCH", () => $"guilds/{guildId}/channels", channels, ids, options: options).ConfigureAwait(false); break; } } @@ -329,7 +363,8 @@ namespace Discord.API try { - return await SendAsync("GET", $"channels/{channelId}/messages/{messageId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendAsync("GET", () => $"channels/{channelId}/messages/{messageId}", ids, options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -359,12 +394,13 @@ namespace Discord.API break; } - string endpoint; + var ids = new BucketIds(channelId: channelId); + Expression> endpoint; if (relativeId != null) - endpoint = $"channels/{channelId}/messages?limit={limit}&{relativeDir}={relativeId}"; + endpoint = () => $"channels/{channelId}/messages?limit={limit}&{relativeDir}={relativeId}"; else - endpoint = $"channels/{channelId}/messages?limit={limit}"; - return await SendAsync>("GET", endpoint, options: options).ConfigureAwait(false); + endpoint = () =>$"channels/{channelId}/messages?limit={limit}"; + return await SendAsync>("GET", endpoint, ids, options: options).ConfigureAwait(false); } public async Task CreateMessageAsync(ulong channelId, CreateMessageParams args, RequestOptions options = null) { @@ -375,7 +411,8 @@ namespace Discord.API throw new ArgumentException($"Message content is too long, length must be less or equal to {DiscordConfig.MaxMessageSize}.", nameof(args.Content)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("POST", $"channels/{channelId}/messages", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("POST", () => $"channels/{channelId}/messages", args, ids, options: options).ConfigureAwait(false); } public async Task UploadFileAsync(ulong channelId, UploadFileParams args, RequestOptions options = null) { @@ -393,7 +430,8 @@ namespace Discord.API throw new ArgumentOutOfRangeException($"Message content is too long, length must be less or equal to {DiscordConfig.MaxMessageSize}.", nameof(args.Content)); } - return await SendMultipartAsync("POST", $"channels/{channelId}/messages", args.ToDictionary(), options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendMultipartAsync("POST", () => $"channels/{channelId}/messages", args.ToDictionary(), ids, options: options).ConfigureAwait(false); } public async Task DeleteMessageAsync(ulong channelId, ulong messageId, RequestOptions options = null) { @@ -401,7 +439,8 @@ namespace Discord.API Preconditions.NotEqual(messageId, 0, nameof(messageId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"channels/{channelId}/messages/{messageId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("DELETE", () => $"channels/{channelId}/messages/{messageId}", ids, options: options).ConfigureAwait(false); } public async Task DeleteMessagesAsync(ulong channelId, DeleteMessagesParams args, RequestOptions options = null) { @@ -419,7 +458,8 @@ namespace Discord.API await DeleteMessageAsync(channelId, args.MessageIds[0]).ConfigureAwait(false); break; default: - await SendJsonAsync("POST", $"channels/{channelId}/messages/bulk-delete", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendJsonAsync("POST", () => $"channels/{channelId}/messages/bulk-delete", args, ids, options: options).ConfigureAwait(false); break; } } @@ -436,7 +476,8 @@ namespace Discord.API } options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"channels/{channelId}/messages/{messageId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("PATCH", () => $"channels/{channelId}/messages/{messageId}", args, ids, options: options).ConfigureAwait(false); } public async Task AckMessageAsync(ulong channelId, ulong messageId, RequestOptions options = null) { @@ -444,14 +485,16 @@ namespace Discord.API Preconditions.NotEqual(messageId, 0, nameof(messageId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("POST", $"channels/{channelId}/messages/{messageId}/ack", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("POST", () => $"channels/{channelId}/messages/{messageId}/ack", ids, options: options).ConfigureAwait(false); } public async Task TriggerTypingIndicatorAsync(ulong channelId, RequestOptions options = null) { Preconditions.NotEqual(channelId, 0, nameof(channelId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("POST", $"channels/{channelId}/typing", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("POST", () => $"channels/{channelId}/typing", ids, options: options).ConfigureAwait(false); } //Channel Permissions @@ -462,7 +505,8 @@ namespace Discord.API Preconditions.NotNull(args, nameof(args)); options = RequestOptions.CreateOrClone(options); - await SendJsonAsync("PUT", $"channels/{channelId}/permissions/{targetId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendJsonAsync("PUT", () => $"channels/{channelId}/permissions/{targetId}", args, ids, options: options).ConfigureAwait(false); } public async Task DeleteChannelPermissionAsync(ulong channelId, ulong targetId, RequestOptions options = null) { @@ -470,7 +514,8 @@ namespace Discord.API Preconditions.NotEqual(targetId, 0, nameof(targetId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"channels/{channelId}/permissions/{targetId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("DELETE", () => $"channels/{channelId}/permissions/{targetId}", ids, options: options).ConfigureAwait(false); } //Channel Pins @@ -480,7 +525,8 @@ namespace Discord.API Preconditions.GreaterThan(messageId, 0, nameof(messageId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("PUT", $"channels/{channelId}/pins/{messageId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("PUT", () => $"channels/{channelId}/pins/{messageId}", ids, options: options).ConfigureAwait(false); } public async Task RemovePinAsync(ulong channelId, ulong messageId, RequestOptions options = null) @@ -489,14 +535,16 @@ namespace Discord.API Preconditions.NotEqual(messageId, 0, nameof(messageId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"channels/{channelId}/pins/{messageId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("DELETE", () => $"channels/{channelId}/pins/{messageId}", ids, options: options).ConfigureAwait(false); } public async Task> GetPinsAsync(ulong channelId, RequestOptions options = null) { Preconditions.NotEqual(channelId, 0, nameof(channelId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"channels/{channelId}/pins", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendAsync>("GET", () => $"channels/{channelId}/pins", ids, options: options).ConfigureAwait(false); } //Channel Recipients @@ -506,7 +554,8 @@ namespace Discord.API Preconditions.GreaterThan(userId, 0, nameof(userId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("PUT", $"channels/{channelId}/recipients/{userId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("PUT", () => $"channels/{channelId}/recipients/{userId}", ids, options: options).ConfigureAwait(false); } public async Task RemoveGroupRecipientAsync(ulong channelId, ulong userId, RequestOptions options = null) @@ -515,7 +564,8 @@ namespace Discord.API Preconditions.NotEqual(userId, 0, nameof(userId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"channels/{channelId}/recipients/{userId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + await SendAsync("DELETE", () => $"channels/{channelId}/recipients/{userId}", ids, options: options).ConfigureAwait(false); } //Guilds @@ -526,7 +576,8 @@ namespace Discord.API try { - return await SendAsync("GET", $"guilds/{guildId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("GET", () => $"guilds/{guildId}", ids, options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -536,22 +587,24 @@ namespace Discord.API Preconditions.NotNullOrWhitespace(args.Name, nameof(args.Name)); Preconditions.NotNullOrWhitespace(args.RegionId, nameof(args.RegionId)); options = RequestOptions.CreateOrClone(options); - - return await SendJsonAsync("POST", "guilds", args, options: options).ConfigureAwait(false); + + return await SendJsonAsync("POST", () => "guilds", args, new BucketIds(), options: options).ConfigureAwait(false); } public async Task DeleteGuildAsync(ulong guildId, RequestOptions options = null) { Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("DELETE", $"guilds/{guildId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("DELETE", () => $"guilds/{guildId}", ids, options: options).ConfigureAwait(false); } public async Task LeaveGuildAsync(ulong guildId, RequestOptions options = null) { Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("DELETE", $"users/@me/guilds/{guildId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("DELETE", () => $"users/@me/guilds/{guildId}", ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildAsync(ulong guildId, ModifyGuildParams args, RequestOptions options = null) { @@ -564,7 +617,8 @@ namespace Discord.API Preconditions.NotNull(args.RegionId, nameof(args.RegionId)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"guilds/{guildId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("PATCH", () => $"guilds/{guildId}", args, ids, options: options).ConfigureAwait(false); } public async Task BeginGuildPruneAsync(ulong guildId, GuildPruneParams args, RequestOptions options = null) { @@ -573,7 +627,8 @@ namespace Discord.API Preconditions.AtLeast(args.Days, 0, nameof(args.Days)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("POST", $"guilds/{guildId}/prune", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("POST", () => $"guilds/{guildId}/prune", args, ids, options: options).ConfigureAwait(false); } public async Task GetGuildPruneCountAsync(ulong guildId, GuildPruneParams args, RequestOptions options = null) { @@ -582,7 +637,8 @@ namespace Discord.API Preconditions.AtLeast(args.Days, 0, nameof(args.Days)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("GET", $"guilds/{guildId}/prune", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("GET", () => $"guilds/{guildId}/prune", args, ids, options: options).ConfigureAwait(false); } //Guild Bans @@ -591,7 +647,8 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/bans", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/bans", ids, options: options).ConfigureAwait(false); } public async Task CreateGuildBanAsync(ulong guildId, ulong userId, CreateGuildBanParams args, RequestOptions options = null) { @@ -601,7 +658,8 @@ namespace Discord.API Preconditions.AtLeast(args.DeleteMessageDays, 0, nameof(args.DeleteMessageDays)); options = RequestOptions.CreateOrClone(options); - await SendAsync("PUT", $"guilds/{guildId}/bans/{userId}?delete-message-days={args.DeleteMessageDays}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendAsync("PUT", () => $"guilds/{guildId}/bans/{userId}?delete-message-days={args.DeleteMessageDays}", ids, options: options).ConfigureAwait(false); } public async Task RemoveGuildBanAsync(ulong guildId, ulong userId, RequestOptions options = null) { @@ -609,7 +667,8 @@ namespace Discord.API Preconditions.NotEqual(userId, 0, nameof(userId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"guilds/{guildId}/bans/{userId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendAsync("DELETE", () => $"guilds/{guildId}/bans/{userId}", ids, options: options).ConfigureAwait(false); } //Guild Embeds @@ -620,7 +679,8 @@ namespace Discord.API try { - return await SendAsync("GET", $"guilds/{guildId}/embed", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("GET", () => $"guilds/{guildId}/embed", ids, options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -630,7 +690,8 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"guilds/{guildId}/embed", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("PATCH", () => $"guilds/{guildId}/embed", args, ids, options: options).ConfigureAwait(false); } //Guild Integrations @@ -639,7 +700,8 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/integrations", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/integrations", ids, options: options).ConfigureAwait(false); } public async Task CreateGuildIntegrationAsync(ulong guildId, CreateGuildIntegrationParams args, RequestOptions options = null) { @@ -648,7 +710,8 @@ namespace Discord.API Preconditions.NotEqual(args.Id, 0, nameof(args.Id)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("POST", $"guilds/{guildId}/integrations", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("POST", () => $"guilds/{guildId}/integrations", ids, options: options).ConfigureAwait(false); } public async Task DeleteGuildIntegrationAsync(ulong guildId, ulong integrationId, RequestOptions options = null) { @@ -656,7 +719,8 @@ namespace Discord.API Preconditions.NotEqual(integrationId, 0, nameof(integrationId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("DELETE", $"guilds/{guildId}/integrations/{integrationId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("DELETE", () => $"guilds/{guildId}/integrations/{integrationId}", ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildIntegrationAsync(ulong guildId, ulong integrationId, ModifyGuildIntegrationParams args, RequestOptions options = null) { @@ -667,7 +731,8 @@ namespace Discord.API Preconditions.AtLeast(args.ExpireGracePeriod, 0, nameof(args.ExpireGracePeriod)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"guilds/{guildId}/integrations/{integrationId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("PATCH", () => $"guilds/{guildId}/integrations/{integrationId}", args, ids, options: options).ConfigureAwait(false); } public async Task SyncGuildIntegrationAsync(ulong guildId, ulong integrationId, RequestOptions options = null) { @@ -675,7 +740,8 @@ namespace Discord.API Preconditions.NotEqual(integrationId, 0, nameof(integrationId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("POST", $"guilds/{guildId}/integrations/{integrationId}/sync", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("POST", () => $"guilds/{guildId}/integrations/{integrationId}/sync", ids, options: options).ConfigureAwait(false); } //Guild Invites @@ -694,7 +760,7 @@ namespace Discord.API try { - return await SendAsync("GET", $"invites/{inviteId}", options: options).ConfigureAwait(false); + return await SendAsync("GET", () => $"invites/{inviteId}", new BucketIds(), options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -703,14 +769,16 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/invites", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/invites", ids, options: options).ConfigureAwait(false); } public async Task> GetChannelInvitesAsync(ulong channelId, RequestOptions options = null) { Preconditions.NotEqual(channelId, 0, nameof(channelId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"channels/{channelId}/invites", options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendAsync>("GET", () => $"channels/{channelId}/invites", ids, options: options).ConfigureAwait(false); } public async Task CreateChannelInviteAsync(ulong channelId, CreateChannelInviteParams args, RequestOptions options = null) { @@ -720,21 +788,22 @@ namespace Discord.API Preconditions.AtLeast(args.MaxUses, 0, nameof(args.MaxUses)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("POST", $"channels/{channelId}/invites", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(channelId: channelId); + return await SendJsonAsync("POST", () => $"channels/{channelId}/invites", args, ids, options: options).ConfigureAwait(false); } - public async Task DeleteInviteAsync(string inviteCode, RequestOptions options = null) + public async Task DeleteInviteAsync(string inviteId, RequestOptions options = null) { - Preconditions.NotNullOrEmpty(inviteCode, nameof(inviteCode)); + Preconditions.NotNullOrEmpty(inviteId, nameof(inviteId)); options = RequestOptions.CreateOrClone(options); - - return await SendAsync("DELETE", $"invites/{inviteCode}", options: options).ConfigureAwait(false); + + return await SendAsync("DELETE", () => $"invites/{inviteId}", new BucketIds(), options: options).ConfigureAwait(false); } - public async Task AcceptInviteAsync(string inviteCode, RequestOptions options = null) + public async Task AcceptInviteAsync(string inviteId, RequestOptions options = null) { - Preconditions.NotNullOrEmpty(inviteCode, nameof(inviteCode)); + Preconditions.NotNullOrEmpty(inviteId, nameof(inviteId)); options = RequestOptions.CreateOrClone(options); - - await SendAsync("POST", $"invites/{inviteCode}", options: options).ConfigureAwait(false); + + await SendAsync("POST", () => $"invites/{inviteId}", new BucketIds(), options: options).ConfigureAwait(false); } //Guild Members @@ -746,7 +815,8 @@ namespace Discord.API try { - return await SendAsync("GET", $"guilds/{guildId}/members/{userId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("GET", () => $"guilds/{guildId}/members/{userId}", ids, options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -761,9 +831,10 @@ namespace Discord.API int limit = args.Limit.GetValueOrDefault(int.MaxValue); ulong afterUserId = args.AfterUserId.GetValueOrDefault(0); - - string endpoint = $"guilds/{guildId}/members?limit={limit}&after={afterUserId}"; - return await SendAsync>("GET", endpoint, options: options).ConfigureAwait(false); + + var ids = new BucketIds(guildId: guildId); + Expression> endpoint = () => $"guilds/{guildId}/members?limit={limit}&after={afterUserId}"; + return await SendAsync>("GET", endpoint, ids, options: options).ConfigureAwait(false); } public async Task RemoveGuildMemberAsync(ulong guildId, ulong userId, RequestOptions options = null) { @@ -771,7 +842,8 @@ namespace Discord.API Preconditions.NotEqual(userId, 0, nameof(userId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"guilds/{guildId}/members/{userId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendAsync("DELETE", () => $"guilds/{guildId}/members/{userId}", ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildMemberAsync(ulong guildId, ulong userId, ModifyGuildMemberParams args, RequestOptions options = null) { @@ -790,7 +862,8 @@ namespace Discord.API } if (!isCurrentUser || args.Deaf.IsSpecified || args.Mute.IsSpecified || args.RoleIds.IsSpecified) { - await SendJsonAsync("PATCH", $"guilds/{guildId}/members/{userId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendJsonAsync("PATCH", () => $"guilds/{guildId}/members/{userId}", args, ids, options: options).ConfigureAwait(false); } } @@ -800,14 +873,16 @@ namespace Discord.API Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/roles", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/roles", ids, options: options).ConfigureAwait(false); } public async Task CreateGuildRoleAsync(ulong guildId, RequestOptions options = null) { Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync("POST", $"guilds/{guildId}/roles", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync("POST", () => $"guilds/{guildId}/roles", ids, options: options).ConfigureAwait(false); } public async Task DeleteGuildRoleAsync(ulong guildId, ulong roleId, RequestOptions options = null) { @@ -815,7 +890,8 @@ namespace Discord.API Preconditions.NotEqual(roleId, 0, nameof(roleId)); options = RequestOptions.CreateOrClone(options); - await SendAsync("DELETE", $"guilds/{guildId}/roles/{roleId}", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendAsync("DELETE", () => $"guilds/{guildId}/roles/{roleId}", ids, options: options).ConfigureAwait(false); } public async Task ModifyGuildRoleAsync(ulong guildId, ulong roleId, ModifyGuildRoleParams args, RequestOptions options = null) { @@ -827,7 +903,8 @@ namespace Discord.API Preconditions.AtLeast(args.Position, 0, nameof(args.Position)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", $"guilds/{guildId}/roles/{roleId}", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync("PATCH", () => $"guilds/{guildId}/roles/{roleId}", args, ids, options: options).ConfigureAwait(false); } public async Task> ModifyGuildRolesAsync(ulong guildId, IEnumerable args, RequestOptions options = null) { @@ -843,7 +920,8 @@ namespace Discord.API case 1: return ImmutableArray.Create(await ModifyGuildRoleAsync(guildId, roles[0].Id, roles[0]).ConfigureAwait(false)); default: - return await SendJsonAsync>("PATCH", $"guilds/{guildId}/roles", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendJsonAsync>("PATCH", () => $"guilds/{guildId}/roles", args, ids, options: options).ConfigureAwait(false); } } @@ -855,7 +933,7 @@ namespace Discord.API try { - return await SendAsync("GET", $"users/{userId}", options: options).ConfigureAwait(false); + return await SendAsync("GET", () => $"users/{userId}", new BucketIds(), options: options).ConfigureAwait(false); } catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.NotFound) { return null; } } @@ -864,27 +942,27 @@ namespace Discord.API public async Task GetMyUserAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", "users/@me", options: options).ConfigureAwait(false); + return await SendAsync("GET", () => "users/@me", new BucketIds(), options: options).ConfigureAwait(false); } public async Task> GetMyConnectionsAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", "users/@me/connections", options: options).ConfigureAwait(false); + return await SendAsync>("GET", () => "users/@me/connections", new BucketIds(), options: options).ConfigureAwait(false); } public async Task> GetMyPrivateChannelsAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", "users/@me/channels", options: options).ConfigureAwait(false); + return await SendAsync>("GET", () => "users/@me/channels", new BucketIds(), options: options).ConfigureAwait(false); } public async Task> GetMyGuildsAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", "users/@me/guilds", options: options).ConfigureAwait(false); + return await SendAsync>("GET", () => "users/@me/guilds", new BucketIds(), options: options).ConfigureAwait(false); } public async Task GetMyApplicationAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", "oauth2/applications/@me", options: options).ConfigureAwait(false); + return await SendAsync("GET", () => "oauth2/applications/@me", new BucketIds(), options: options).ConfigureAwait(false); } public async Task ModifySelfAsync(ModifyCurrentUserParams args, RequestOptions options = null) { @@ -892,7 +970,7 @@ namespace Discord.API Preconditions.NotNullOrEmpty(args.Username, nameof(args.Username)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("PATCH", "users/@me", args, options: options).ConfigureAwait(false); + return await SendJsonAsync("PATCH", () => "users/@me", args, new BucketIds(), options: options).ConfigureAwait(false); } public async Task ModifyMyNickAsync(ulong guildId, ModifyCurrentUserNickParams args, RequestOptions options = null) { @@ -900,7 +978,8 @@ namespace Discord.API Preconditions.NotNull(args.Nickname, nameof(args.Nickname)); options = RequestOptions.CreateOrClone(options); - await SendJsonAsync("PATCH", $"guilds/{guildId}/members/@me/nick", args, options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + await SendJsonAsync("PATCH", () => $"guilds/{guildId}/members/@me/nick", args, ids, options: options).ConfigureAwait(false); } public async Task CreateDMChannelAsync(CreateDMChannelParams args, RequestOptions options = null) { @@ -908,21 +987,22 @@ namespace Discord.API Preconditions.GreaterThan(args.RecipientId, 0, nameof(args.RecipientId)); options = RequestOptions.CreateOrClone(options); - return await SendJsonAsync("POST", $"users/@me/channels", args, options: options).ConfigureAwait(false); + return await SendJsonAsync("POST", () => "users/@me/channels", args, new BucketIds(), options: options).ConfigureAwait(false); } //Voice Regions public async Task> GetVoiceRegionsAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", "voice/regions", options: options).ConfigureAwait(false); + return await SendAsync>("GET", () => "voice/regions", new BucketIds(), options: options).ConfigureAwait(false); } public async Task> GetGuildVoiceRegionsAsync(ulong guildId, RequestOptions options = null) { Preconditions.NotEqual(guildId, 0, nameof(guildId)); options = RequestOptions.CreateOrClone(options); - return await SendAsync>("GET", $"guilds/{guildId}/regions", options: options).ConfigureAwait(false); + var ids = new BucketIds(guildId: guildId); + return await SendAsync>("GET", () => $"guilds/{guildId}/regions", ids, options: options).ConfigureAwait(false); } //Helpers @@ -946,5 +1026,115 @@ namespace Discord.API using (JsonReader reader = new JsonTextReader(text)) return _serializer.Deserialize(reader); } + internal string GetBucketId(ulong guildId = 0, ulong channelId = 0, [CallerMemberName] string methodName = "") + { + if (guildId != 0) + { + if (channelId != 0) + return $"{methodName}({guildId}/{channelId})"; + else + return $"{methodName}({guildId})"; + } + else if (channelId != 0) + return $"{methodName}({channelId})"; + return $"{methodName}()"; + } + + internal class BucketIds + { + public ulong GuildId { get; } + public ulong ChannelId { get; } + + internal BucketIds(ulong guildId = 0, ulong channelId = 0) + { + GuildId = guildId; + ChannelId = channelId; + } + internal object[] ToArray() + => new object[] { GuildId, ChannelId }; + + internal static int? GetIndex(string name) + { + switch (name) + { + case "guildId": return 0; + case "channelId": return 1; + default: + return null; + } + } + } + + private string GetEndpoint(Expression> endpointExpr) + { + return endpointExpr.Compile()(); + } + private string GetBucketId(BucketIds ids, Expression> endpointExpr, string callingMethod) + { + return _bucketIdGenerators.GetOrAdd(callingMethod, x => CreateBucketId(endpointExpr))(ids); + } + + private Func CreateBucketId(Expression> endpoint) + { + try + { + //Is this a constant string? + if (endpoint.Body.NodeType == ExpressionType.Constant) + return x => (endpoint.Body as ConstantExpression).Value.ToString(); + + var builder = new StringBuilder(); + var methodCall = endpoint.Body as MethodCallExpression; + var methodArgs = methodCall.Arguments.ToArray(); + string format = (methodArgs[0] as ConstantExpression).Value as string; + + int endIndex = format.IndexOf('?'); //Dont include params + if (endIndex == -1) + endIndex = format.Length; + + int lastIndex = 0; + while (true) + { + int leftIndex = format.IndexOf("{", lastIndex); + if (leftIndex == -1 || leftIndex > endIndex) + { + builder.Append(format, lastIndex, endIndex - lastIndex); + break; + } + builder.Append(format, lastIndex, leftIndex); + int rightIndex = format.IndexOf("}", leftIndex); + + int argId = int.Parse(format.Substring(leftIndex + 1, rightIndex - leftIndex - 1)); + string fieldName = GetFieldName(methodArgs[argId + 1]); + int? mappedId; + + mappedId = BucketIds.GetIndex(fieldName); + if(!mappedId.HasValue && rightIndex != endIndex && format[rightIndex + 1] == '/') //Ignore the next slash + rightIndex++; + + if (mappedId.HasValue) + builder.Append($"{{{mappedId.Value}}}"); + + lastIndex = rightIndex + 1; + } + + format = builder.ToString(); + return x => string.Format(format, x.ToArray()); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to generate the bucket id for this operation", ex); + } + } + + private static string GetFieldName(Expression expr) + { + if (expr.NodeType == ExpressionType.Convert) + expr = (expr as UnaryExpression).Operand; + + if (expr.NodeType != ExpressionType.MemberAccess) + throw new InvalidOperationException("Unsupported expression"); + + return (expr as MemberExpression).Member.Name; + } } } diff --git a/src/Discord.Net.Core/Logging/LogManager.cs b/src/Discord.Net.Core/Logging/LogManager.cs index 3b7a4b960..104e02835 100644 --- a/src/Discord.Net.Core/Logging/LogManager.cs +++ b/src/Discord.Net.Core/Logging/LogManager.cs @@ -1,5 +1,4 @@ using System; -using System.Runtime.InteropServices; using System.Threading.Tasks; namespace Discord.Logging diff --git a/src/Discord.Net.Core/Net/Queue/ClientBucket.cs b/src/Discord.Net.Core/Net/Queue/ClientBucket.cs new file mode 100644 index 000000000..93e5cfd23 --- /dev/null +++ b/src/Discord.Net.Core/Net/Queue/ClientBucket.cs @@ -0,0 +1,26 @@ +using System.Collections.Immutable; + +namespace Discord.Net.Queue +{ + public struct ClientBucket + { + private static readonly ImmutableDictionary _defs; + static ClientBucket() + { + var builder = ImmutableDictionary.CreateBuilder(); + builder.Add("", new ClientBucket(5, 5)); + _defs = builder.ToImmutable(); + } + + public static ClientBucket Get(string id) => _defs[id]; + + public int WindowCount { get; } + public int WindowSeconds { get; } + + public ClientBucket(int count, int seconds) + { + WindowCount = count; + WindowSeconds = seconds; + } + } +} diff --git a/src/Discord.Net.Core/Net/Queue/RequestQueue.cs b/src/Discord.Net.Core/Net/Queue/RequestQueue.cs index d25c1f340..28caca1c2 100644 --- a/src/Discord.Net.Core/Net/Queue/RequestQueue.cs +++ b/src/Discord.Net.Core/Net/Queue/RequestQueue.cs @@ -1,85 +1,127 @@ using System; using System.Collections.Concurrent; +using System.Diagnostics; using System.IO; +using System.Linq; using System.Threading; using System.Threading.Tasks; namespace Discord.Net.Queue { - public class RequestQueue + public class RequestQueue : IDisposable { - public event Func RateLimitTriggered; - - private readonly SemaphoreSlim _lock; - private readonly ConcurrentDictionary _buckets; + public event Func RateLimitTriggered; + + internal TokenType TokenType { get; set; } + + private readonly ConcurrentDictionary _buckets; + private readonly SemaphoreSlim _tokenLock; private CancellationTokenSource _clearToken; private CancellationToken _parentToken; - private CancellationToken _cancelToken; + private CancellationToken _requestCancelToken; //Parent token + Clear token + private CancellationTokenSource _cancelToken; //Dispose token + private DateTimeOffset _waitUntil; + + private Task _cleanupTask; public RequestQueue() { - _lock = new SemaphoreSlim(1, 1); + _tokenLock = new SemaphoreSlim(1, 1); _clearToken = new CancellationTokenSource(); - _cancelToken = CancellationToken.None; + _cancelToken = new CancellationTokenSource(); + _requestCancelToken = CancellationToken.None; _parentToken = CancellationToken.None; + + _buckets = new ConcurrentDictionary(); - _buckets = new ConcurrentDictionary(); + _cleanupTask = RunCleanup(); } + public async Task SetCancelTokenAsync(CancellationToken cancelToken) { - await _lock.WaitAsync().ConfigureAwait(false); + await _tokenLock.WaitAsync().ConfigureAwait(false); try { _parentToken = cancelToken; - _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancelToken, _clearToken.Token).Token; + _requestCancelToken = CancellationTokenSource.CreateLinkedTokenSource(cancelToken, _clearToken.Token).Token; + } + finally { _tokenLock.Release(); } + } + public async Task ClearAsync() + { + await _tokenLock.WaitAsync().ConfigureAwait(false); + try + { + _clearToken?.Cancel(); + _clearToken = new CancellationTokenSource(); + if (_parentToken != null) + _requestCancelToken = CancellationTokenSource.CreateLinkedTokenSource(_clearToken.Token, _parentToken).Token; + else + _requestCancelToken = _clearToken.Token; } - finally { _lock.Release(); } + finally { _tokenLock.Release(); } } public async Task SendAsync(RestRequest request) { - request.CancelToken = _cancelToken; - var bucket = GetOrCreateBucket(request.Options.BucketId); + request.CancelToken = _requestCancelToken; + var bucket = GetOrCreateBucket(request.BucketId); return await bucket.SendAsync(request).ConfigureAwait(false); } - public async Task SendAsync(WebSocketRequest request) + public async Task SendAsync(WebSocketRequest request) { - request.CancelToken = _cancelToken; - var bucket = GetOrCreateBucket(request.Options.BucketId); - return await bucket.SendAsync(request).ConfigureAwait(false); + //TODO: Re-impl websocket buckets + request.CancelToken = _requestCancelToken; + await request.SendAsync().ConfigureAwait(false); } - - private RequestQueueBucket GetOrCreateBucket(string id) + + internal async Task EnterGlobalAsync(int id, RestRequest request) + { + int millis = (int)Math.Ceiling((_waitUntil - DateTimeOffset.UtcNow).TotalMilliseconds); + if (millis > 0) + { + Debug.WriteLine($"[{id}] Sleeping {millis} ms (Pre-emptive) [Global]"); + await Task.Delay(millis).ConfigureAwait(false); + } + } + internal void PauseGlobal(RateLimitInfo info, TimeSpan lag) { - return new RequestQueueBucket(this, id, null); + _waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + lag.TotalMilliseconds); } - public void DestroyBucket(string id) + private RequestBucket GetOrCreateBucket(string id) { - //Assume this object is locked - RequestQueueBucket bucket; - _buckets.TryRemove(id, out bucket); + return _buckets.GetOrAdd(id, x => new RequestBucket(this, x)); + } + internal async Task RaiseRateLimitTriggered(string bucketId, RateLimitInfo? info) + { + await RateLimitTriggered(bucketId, info).ConfigureAwait(false); } - public async Task ClearAsync() + private async Task RunCleanup() { - await _lock.WaitAsync().ConfigureAwait(false); try { - _clearToken?.Cancel(); - _clearToken = new CancellationTokenSource(); - if (_parentToken != null) - _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_clearToken.Token, _parentToken).Token; - else - _cancelToken = _clearToken.Token; + while (!_cancelToken.IsCancellationRequested) + { + var now = DateTimeOffset.UtcNow; + foreach (var bucket in _buckets.Select(x => x.Value)) + { + RequestBucket ignored; + if ((now - bucket.LastAttemptAt).TotalMinutes > 1.0) + _buckets.TryRemove(bucket.Id, out ignored); + } + await Task.Delay(60000, _cancelToken.Token); //Runs each minute + } } - finally { _lock.Release(); } + catch (OperationCanceledException) { } + catch (ObjectDisposedException) { } } - internal async Task RaiseRateLimitTriggered(string id, RequestQueueBucket bucket, int? millis) + public void Dispose() { - await RateLimitTriggered(id, bucket, millis).ConfigureAwait(false); + _cancelToken.Dispose(); } } } diff --git a/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs index a94d0f05e..211a68eab 100644 --- a/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs +++ b/src/Discord.Net.Core/Net/Queue/RequestQueueBucket.cs @@ -1,5 +1,7 @@ -#pragma warning disable CS4014 +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; +using System.Diagnostics; using System.IO; using System.Net; using System.Threading; @@ -7,145 +9,230 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class RequestQueueBucket + internal class RequestBucket { + private readonly object _lock; private readonly RequestQueue _queue; - private readonly SemaphoreSlim _semaphore; - private readonly object _pauseLock; - private int _pauseEndTick; - private TaskCompletionSource _resumeNotifier; + private int _semaphore; + private DateTimeOffset? _resetTick; - public string Id { get; } - public RequestQueueBucket Parent { get; } - public int WindowSeconds { get; } + public string Id { get; private set; } + public int WindowCount { get; private set; } + public DateTimeOffset LastAttemptAt { get; private set; } - public RequestQueueBucket(RequestQueue queue, string id, RequestQueueBucket parent = null) + public RequestBucket(RequestQueue queue, string id) { _queue = queue; Id = id; - _semaphore = new SemaphoreSlim(5, 5); - Parent = parent; - _pauseLock = new object(); - _resumeNotifier = new TaskCompletionSource(); - _resumeNotifier.SetResult(0); - } + _lock = new object(); - public async Task SendAsync(IQueuedRequest request) + if (queue.TokenType == TokenType.User) + WindowCount = ClientBucket.Get(Id).WindowCount; + else + WindowCount = 1; //Only allow one request until we get a header back + _semaphore = WindowCount; + _resetTick = null; + LastAttemptAt = DateTimeOffset.UtcNow; + } + + static int nextId = 0; + public async Task SendAsync(RestRequest request) { + int id = Interlocked.Increment(ref nextId); + Debug.WriteLine($"[{id}] Start"); + LastAttemptAt = DateTimeOffset.UtcNow; while (true) { - try + await _queue.EnterGlobalAsync(id, request).ConfigureAwait(false); + await EnterAsync(id, request).ConfigureAwait(false); + + Debug.WriteLine($"[{id}] Sending..."); + var response = await request.SendAsync().ConfigureAwait(false); + TimeSpan lag = DateTimeOffset.UtcNow - DateTimeOffset.Parse(response.Headers["Date"]); + var info = new RateLimitInfo(response.Headers); + + if (response.StatusCode < (HttpStatusCode)200 || response.StatusCode >= (HttpStatusCode)300) { - return await SendAsyncInternal(request).ConfigureAwait(false); + switch (response.StatusCode) + { + case (HttpStatusCode)429: + if (info.IsGlobal) + { + Debug.WriteLine($"[{id}] (!) 429 [Global]"); + _queue.PauseGlobal(info, lag); + } + else + { + Debug.WriteLine($"[{id}] (!) 429"); + Update(id, info, lag); + } + await _queue.RaiseRateLimitTriggered(Id, info).ConfigureAwait(false); + continue; //Retry + case HttpStatusCode.BadGateway: //502 + Debug.WriteLine($"[{id}] (!) 502"); + continue; //Continue + default: + string reason = null; + if (response.Stream != null) + { + try + { + using (var reader = new StreamReader(response.Stream)) + using (var jsonReader = new JsonTextReader(reader)) + { + var json = JToken.Load(jsonReader); + reason = json.Value("message"); + } + } + catch { } + } + throw new HttpException(response.StatusCode, reason); + } } - catch (HttpRateLimitException ex) + else { - //When a 429 occurs, we drop all our locks. - //This is generally safe though since 429s actually occuring should be very rare. - await _queue.RaiseRateLimitTriggered(Id, this, ex.RetryAfterMilliseconds).ConfigureAwait(false); - Pause(ex.RetryAfterMilliseconds); + Debug.WriteLine($"[{id}] Success"); + Update(id, info, lag); + Debug.WriteLine($"[{id}] Stop"); + return response.Stream; } } } - private async Task SendAsyncInternal(IQueuedRequest request) + + private async Task EnterAsync(int id, RestRequest request) { - var endTick = request.TimeoutTick; + int windowCount; + DateTimeOffset? resetAt; + bool isRateLimited = false; - //Wait until a spot is open in our bucket - if (_semaphore != null) - await EnterAsync(endTick).ConfigureAwait(false); - try + while (true) { - while (true) + if (DateTimeOffset.UtcNow > request.TimeoutAt || request.CancelToken.IsCancellationRequested) { - //Get our 429 state - Task notifier; - int resumeTime; - - lock (_pauseLock) - { - notifier = _resumeNotifier.Task; - resumeTime = _pauseEndTick; - } - - //Are we paused due to a 429? - if (!notifier.IsCompleted) - { - //If the 429 ends after the maximum time for this request, timeout immediately - if (endTick.HasValue && endTick.Value < resumeTime) - throw new TimeoutException(); + if (!isRateLimited) + throw new TimeoutException(); + else + throw new RateLimitedException(); + } - //Wait for the 429 to complete - await notifier.ConfigureAwait(false); - } + lock (_lock) + { + windowCount = WindowCount; + resetAt = _resetTick; + } - try + DateTimeOffset? timeoutAt = request.TimeoutAt; + if (windowCount > 0 && Interlocked.Decrement(ref _semaphore) < 0) + { + isRateLimited = true; + await _queue.RaiseRateLimitTriggered(Id, null).ConfigureAwait(false); + if (resetAt.HasValue) { - //If there's a parent bucket, pass this request to them - if (Parent != null) - return await Parent.SendAsyncInternal(request).ConfigureAwait(false); - - //We have all our semaphores, send the request - return await request.SendAsync().ConfigureAwait(false); + if (resetAt > timeoutAt) + throw new RateLimitedException(); + int millis = (int)Math.Ceiling((resetAt.Value - DateTimeOffset.UtcNow).TotalMilliseconds); + Debug.WriteLine($"[{id}] Sleeping {millis} ms (Pre-emptive)"); + if (millis > 0) + await Task.Delay(millis, request.CancelToken).ConfigureAwait(false); } - catch (HttpException ex) when (ex.StatusCode == HttpStatusCode.BadGateway) + else { - continue; + if ((timeoutAt.Value - DateTimeOffset.UtcNow).TotalMilliseconds < 500.0) + throw new RateLimitedException(); + Debug.WriteLine($"[{id}] Sleeping 500* ms (Pre-emptive)"); + await Task.Delay(500, request.CancelToken).ConfigureAwait(false); } + continue; } - } - finally - { - //Make sure we put this entry back after WindowMilliseconds - if (_semaphore != null) - QueueExitAsync(); + else + Debug.WriteLine($"[{id}] Entered Semaphore ({_semaphore}/{WindowCount} remaining)"); + break; } } - - private void Pause(int milliseconds) + + private void Update(int id, RateLimitInfo info, TimeSpan lag) { - lock (_pauseLock) + lock (_lock) { - //If we aren't already waiting on a 429's time, create a new notifier task - if (_resumeNotifier.Task.IsCompleted) + if (!info.Limit.HasValue && _queue.TokenType != TokenType.User) { - _resumeNotifier = new TaskCompletionSource(); - _pauseEndTick = unchecked(Environment.TickCount + milliseconds); - QueueResumeAsync(_resumeNotifier, milliseconds); + WindowCount = 0; + return; } - } - } - private async Task QueueResumeAsync(TaskCompletionSource resumeNotifier, int millis) - { - await Task.Delay(millis).ConfigureAwait(false); - resumeNotifier.TrySetResultAsync(0); - } - private async Task EnterAsync(int? endTick) - { - if (endTick.HasValue) - { - int millis = unchecked(endTick.Value - Environment.TickCount); - if (millis <= 0 || !await _semaphore.WaitAsync(millis).ConfigureAwait(false)) - throw new TimeoutException(); + bool hasQueuedReset = _resetTick != null; + if (info.Limit.HasValue && WindowCount != info.Limit.Value) + { + WindowCount = info.Limit.Value; + _semaphore = info.Remaining.Value; + Debug.WriteLine($"[{id}] Upgraded Semaphore to {info.Remaining.Value}/{WindowCount} "); + } + + var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + DateTimeOffset resetTick; - if (!await _semaphore.WaitAsync(0).ConfigureAwait(false)) + //Using X-RateLimit-Remaining causes a race condition + /*if (info.Remaining.HasValue) + { + Debug.WriteLine($"[{id}] X-RateLimit-Remaining: " + info.Remaining.Value); + _semaphore = info.Remaining.Value; + }*/ + if (info.RetryAfter.HasValue) { - await _queue.RaiseRateLimitTriggered(Id, this, null).ConfigureAwait(false); + //RetryAfter is more accurate than Reset, where available + resetTick = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value); + Debug.WriteLine($"[{id}] Retry-After: {info.RetryAfter.Value} ({info.RetryAfter.Value} ms)"); + } + else if (info.Reset.HasValue) + { + resetTick = info.Reset.Value.AddSeconds(/*1.0 +*/ lag.TotalSeconds); + int diff = (int)(resetTick - DateTimeOffset.UtcNow).TotalMilliseconds; + Debug.WriteLine($"[{id}] X-RateLimit-Reset: {info.Reset.Value.ToUnixTimeSeconds()} ({diff} ms, {lag.TotalMilliseconds} ms lag)"); + } + else if (_queue.TokenType == TokenType.User) + { + resetTick = DateTimeOffset.UtcNow.AddSeconds(ClientBucket.Get(Id).WindowSeconds); + Debug.WriteLine($"[{id}] Client Bucket: " + ClientBucket.Get(Id).WindowSeconds); + } - millis = unchecked(endTick.Value - Environment.TickCount); - if (millis <= 0 || !await _semaphore.WaitAsync(millis).ConfigureAwait(false)) - throw new TimeoutException(); + if (resetTick == null) + { + resetTick = DateTimeOffset.UtcNow.AddSeconds(1.0); //Forcibly reset in a second + Debug.WriteLine($"[{id}] Unknown Retry Time!"); + } + + if (!hasQueuedReset || resetTick > _resetTick) + { + _resetTick = resetTick; + LastAttemptAt = resetTick; //Make sure we dont destroy this until after its been reset + Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).TotalMilliseconds)} ms"); + + if (!hasQueuedReset) + { + var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds)); + } } } - else - await _semaphore.WaitAsync().ConfigureAwait(false); } - private async Task QueueExitAsync() + private async Task QueueReset(int id, int millis) { - await Task.Delay(WindowSeconds * 1000).ConfigureAwait(false); - _semaphore.Release(); + while (true) + { + if (millis > 0) + await Task.Delay(millis).ConfigureAwait(false); + lock (_lock) + { + millis = (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds); + if (millis <= 0) //Make sure we havent gotten a more accurate reset time + { + Debug.WriteLine($"[{id}] * Reset *"); + _semaphore = WindowCount; + _resetTick = null; + return; + } + } + } } } -} +} \ No newline at end of file diff --git a/src/Discord.Net.Core/Net/Queue/Requests/IQueuedRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/IQueuedRequest.cs deleted file mode 100644 index 492b3a77d..000000000 --- a/src/Discord.Net.Core/Net/Queue/Requests/IQueuedRequest.cs +++ /dev/null @@ -1,14 +0,0 @@ -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Discord.Net.Queue -{ - public interface IQueuedRequest - { - CancellationToken CancelToken { get; } - int? TimeoutTick { get; } - - Task SendAsync(); - } -} diff --git a/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs new file mode 100644 index 000000000..c8d861a11 --- /dev/null +++ b/src/Discord.Net.Core/Net/Queue/Requests/IRequest.cs @@ -0,0 +1,12 @@ +using System; +using System.Threading; + +namespace Discord.Net.Queue +{ + public interface IRequest + { + CancellationToken CancelToken { get; } + DateTimeOffset? TimeoutAt { get; } + string BucketId { get; } + } +} diff --git a/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs index d715b790c..d328a3e26 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/JsonRestRequest.cs @@ -1,5 +1,4 @@ using Discord.Net.Rest; -using System.IO; using System.Threading.Tasks; namespace Discord.Net.Queue @@ -8,13 +7,13 @@ namespace Discord.Net.Queue { public string Json { get; } - public JsonRestRequest(IRestClient client, string method, string endpoint, string json, RequestOptions options) - : base(client, method, endpoint, options) + public JsonRestRequest(IRestClient client, string method, string endpoint, string bucket, string json, RequestOptions options) + : base(client, method, endpoint, bucket, options) { Json = json; } - public override async Task SendAsync() + public override async Task SendAsync() { return await Client.SendAsync(Method, Endpoint, Json, Options).ConfigureAwait(false); } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs index 047e5ed02..e27bb92a0 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/MultipartRestRequest.cs @@ -1,6 +1,5 @@ using Discord.Net.Rest; using System.Collections.Generic; -using System.IO; using System.Threading.Tasks; namespace Discord.Net.Queue @@ -9,13 +8,13 @@ namespace Discord.Net.Queue { public IReadOnlyDictionary MultipartParams { get; } - public MultipartRestRequest(IRestClient client, string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options) - : base(client, method, endpoint, options) + public MultipartRestRequest(IRestClient client, string method, string endpoint, string bucket, IReadOnlyDictionary multipartParams, RequestOptions options) + : base(client, method, endpoint, bucket, options) { MultipartParams = multipartParams; } - public override async Task SendAsync() + public override async Task SendAsync() { return await Client.SendAsync(Method, Endpoint, MultipartParams, Options).ConfigureAwait(false); } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs index 655a79567..8382003c8 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/RestRequest.cs @@ -6,29 +6,31 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class RestRequest : IQueuedRequest + public class RestRequest : IRequest { public IRestClient Client { get; } public string Method { get; } public string Endpoint { get; } - public int? TimeoutTick { get; } + public string BucketId { get; } + public DateTimeOffset? TimeoutAt { get; } public TaskCompletionSource Promise { get; } public RequestOptions Options { get; } public CancellationToken CancelToken { get; internal set; } - public RestRequest(IRestClient client, string method, string endpoint, RequestOptions options) + public RestRequest(IRestClient client, string method, string endpoint, string bucketId, RequestOptions options) { Preconditions.NotNull(options, nameof(options)); Client = client; Method = method; Endpoint = endpoint; + BucketId = bucketId; Options = options; - TimeoutTick = options.Timeout.HasValue ? (int?)unchecked(Environment.TickCount + options.Timeout.Value) : null; + TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; Promise = new TaskCompletionSource(); } - public virtual async Task SendAsync() + public virtual async Task SendAsync() { return await Client.SendAsync(Method, Endpoint, Options).ConfigureAwait(false); } diff --git a/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs index 796517c85..08cdb192c 100644 --- a/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs +++ b/src/Discord.Net.Core/Net/Queue/Requests/WebSocketRequest.cs @@ -6,32 +6,33 @@ using System.Threading.Tasks; namespace Discord.Net.Queue { - public class WebSocketRequest : IQueuedRequest + public class WebSocketRequest : IRequest { public IWebSocketClient Client { get; } + public string BucketId { get; } public byte[] Data { get; } public bool IsText { get; } - public int? TimeoutTick { get; } + public DateTimeOffset? TimeoutAt { get; } public TaskCompletionSource Promise { get; } public RequestOptions Options { get; } public CancellationToken CancelToken { get; internal set; } - public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, RequestOptions options) + public WebSocketRequest(IWebSocketClient client, string bucketId, byte[] data, bool isText, RequestOptions options) { Preconditions.NotNull(options, nameof(options)); Client = client; + BucketId = bucketId; Data = data; IsText = isText; Options = options; - TimeoutTick = options.Timeout.HasValue ? (int?)unchecked(Environment.TickCount + options.Timeout.Value) : null; + TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null; Promise = new TaskCompletionSource(); } - public async Task SendAsync() + public async Task SendAsync() { await Client.SendAsync(Data, 0, Data.Length, IsText).ConfigureAwait(false); - return null; } } } diff --git a/src/Discord.Net.Core/Net/RateLimitException.cs b/src/Discord.Net.Core/Net/RateLimitException.cs deleted file mode 100644 index cb0ca7f28..000000000 --- a/src/Discord.Net.Core/Net/RateLimitException.cs +++ /dev/null @@ -1,17 +0,0 @@ -using System.Net; - -namespace Discord.Net -{ - public class HttpRateLimitException : HttpException - { - public string Id { get; } - public int RetryAfterMilliseconds { get; } - - public HttpRateLimitException(string bucketId, int retryAfterMilliseconds, string reason) - : base((HttpStatusCode)429, reason) - { - Id = bucketId; - RetryAfterMilliseconds = retryAfterMilliseconds; - } - } -} diff --git a/src/Discord.Net.Core/Net/RateLimitInfo.cs b/src/Discord.Net.Core/Net/RateLimitInfo.cs new file mode 100644 index 000000000..2c2faccf8 --- /dev/null +++ b/src/Discord.Net.Core/Net/RateLimitInfo.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; + +namespace Discord.Net +{ + public struct RateLimitInfo + { + public bool IsGlobal { get; } + public int? Limit { get; } + public int? Remaining { get; } + public int? RetryAfter { get; } + public DateTimeOffset? Reset { get; } + + internal RateLimitInfo(Dictionary headers) + { + string temp; + IsGlobal = headers.TryGetValue("X-RateLimit-Global", out temp) ? bool.Parse(temp) : false; + Limit = headers.TryGetValue("X-RateLimit-Limit", out temp) ? int.Parse(temp) : (int?)null; + Remaining = headers.TryGetValue("X-RateLimit-Remaining", out temp) ? int.Parse(temp) : (int?)null; + Reset = headers.TryGetValue("X-RateLimit-Reset", out temp) ? DateTimeOffset.FromUnixTimeSeconds(int.Parse(temp)) : (DateTimeOffset?)null; + RetryAfter = headers.TryGetValue("Retry-After", out temp) ? int.Parse(temp) : (int?)null; + } + } +} diff --git a/src/Discord.Net.Core/Net/RateLimitedException.cs b/src/Discord.Net.Core/Net/RateLimitedException.cs new file mode 100644 index 000000000..e8572f911 --- /dev/null +++ b/src/Discord.Net.Core/Net/RateLimitedException.cs @@ -0,0 +1,12 @@ +using System; + +namespace Discord.Net +{ + public class RateLimitedException : TimeoutException + { + public RateLimitedException() + : base("You are being rate limited.") + { + } + } +} diff --git a/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs b/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs index b06df37b8..02c356efd 100644 --- a/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs +++ b/src/Discord.Net.Core/Net/Rest/DefaultRestClient.cs @@ -1,5 +1,4 @@ using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Globalization; @@ -67,13 +66,13 @@ namespace Discord.Net.Rest _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; } - public async Task SendAsync(string method, string endpoint, RequestOptions options) + public async Task SendAsync(string method, string endpoint, RequestOptions options) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) return await SendInternalAsync(restRequest, options).ConfigureAwait(false); } - public async Task SendAsync(string method, string endpoint, string json, RequestOptions options) + public async Task SendAsync(string method, string endpoint, string json, RequestOptions options) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) @@ -82,7 +81,7 @@ namespace Discord.Net.Rest return await SendInternalAsync(restRequest, options).ConfigureAwait(false); } } - public async Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options) + public async Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options) { string uri = Path.Combine(_baseUrl, endpoint); using (var restRequest = new HttpRequestMessage(GetMethod(method), uri)) @@ -114,50 +113,17 @@ namespace Discord.Net.Rest } } - private async Task SendInternalAsync(HttpRequestMessage request, RequestOptions options) + private async Task SendInternalAsync(HttpRequestMessage request, RequestOptions options) { while (true) { var cancelToken = _cancelToken; //It's okay if another thread changes this, causes a retry to abort HttpResponseMessage response = await _client.SendAsync(request, cancelToken).ConfigureAwait(false); + + var headers = response.Headers.ToDictionary(x => x.Key, x => x.Value.FirstOrDefault()); + var stream = !options.HeaderOnly ? await response.Content.ReadAsStreamAsync().ConfigureAwait(false) : null; - int statusCode = (int)response.StatusCode; - if (statusCode < 200 || statusCode >= 300) //2xx = Success - { - string reason = null; - JToken content = null; - if (response.Content.Headers.GetValues("content-type").FirstOrDefault() == "application/json") - { - try - { - using (var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false)) - using (var reader = new StreamReader(stream)) - using (var json = new JsonTextReader(reader)) - { - content = _errorDeserializer.Deserialize(json); - reason = content.Value("message"); - if (reason == null) //Occasionally an error message is given under a different key because reasons - reason = content.ToString(Formatting.None); - } - } - catch { } //Might have been HTML Should we check for content-type? - } - - if (statusCode == 429 && content != null) - { - //TODO: Include bucket info - string bucketId = content.Value("bucket"); - int retryAfterMillis = content.Value("retry_after"); - throw new HttpRateLimitException(bucketId, retryAfterMillis, reason); - } - else - throw new HttpException(response.StatusCode, reason); - } - - if (options.HeaderOnly) - return null; - else - return await response.Content.ReadAsStreamAsync().ConfigureAwait(false); + return new RestResponse(response.StatusCode, headers, stream); } } diff --git a/src/Discord.Net.Core/Net/Rest/IRestClient.cs b/src/Discord.Net.Core/Net/Rest/IRestClient.cs index aa53bea5b..16cfbe62d 100644 --- a/src/Discord.Net.Core/Net/Rest/IRestClient.cs +++ b/src/Discord.Net.Core/Net/Rest/IRestClient.cs @@ -1,5 +1,5 @@ +using Discord.Net.Queue; using System.Collections.Generic; -using System.IO; using System.Threading; using System.Threading.Tasks; @@ -10,8 +10,8 @@ namespace Discord.Net.Rest void SetHeader(string key, string value); void SetCancelToken(CancellationToken cancelToken); - Task SendAsync(string method, string endpoint, RequestOptions options); - Task SendAsync(string method, string endpoint, string json, RequestOptions options); - Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options); + Task SendAsync(string method, string endpoint, RequestOptions options); + Task SendAsync(string method, string endpoint, string json, RequestOptions options); + Task SendAsync(string method, string endpoint, IReadOnlyDictionary multipartParams, RequestOptions options); } } diff --git a/src/Discord.Net.Core/Net/Rest/RestResponse.cs b/src/Discord.Net.Core/Net/Rest/RestResponse.cs new file mode 100644 index 000000000..412ff4dce --- /dev/null +++ b/src/Discord.Net.Core/Net/Rest/RestResponse.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; +using System.IO; +using System.Net; + +namespace Discord.Net.Rest +{ + public struct RestResponse + { + public HttpStatusCode StatusCode { get; } + public Dictionary Headers { get; } + public Stream Stream { get; } + + public RestResponse(HttpStatusCode statusCode, Dictionary headers, Stream stream) + { + StatusCode = statusCode; + Headers = headers; + Stream = stream; + } + } +} diff --git a/src/Discord.Net.Core/RequestOptions.cs b/src/Discord.Net.Core/RequestOptions.cs index 9c9986b29..1d362fad1 100644 --- a/src/Discord.Net.Core/RequestOptions.cs +++ b/src/Discord.Net.Core/RequestOptions.cs @@ -6,7 +6,6 @@ /// The max time, in milliseconds, to wait for this request to complete. If null, a request will not time out. If a rate limit has been triggered for this request's bucket and will not be unpaused in time, this request will fail immediately. public int? Timeout { get; set; } - public string BucketId { get; set; } public bool HeaderOnly { get; internal set; } internal bool IgnoreState { get; set; } diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs index 62dcfd055..4ed019dfb 100644 --- a/src/Discord.Net.Rest/BaseDiscordClient.cs +++ b/src/Discord.Net.Rest/BaseDiscordClient.cs @@ -40,11 +40,12 @@ namespace Discord.Rest _queueLogger = LogManager.CreateLogger("Queue"); _isFirstLogin = true; - ApiClient.RequestQueue.RateLimitTriggered += async (id, bucket, millis) => + ApiClient.RequestQueue.RateLimitTriggered += async (id, info) => { - await _queueLogger.WarningAsync($"Rate limit triggered (id = \"{id ?? "null"}\")").ConfigureAwait(false); - if (bucket == null && id != null) - await _queueLogger.WarningAsync($"Unknown rate limit bucket \"{id ?? "null"}\"").ConfigureAwait(false); + if (info == null) + await _queueLogger.WarningAsync($"Preemptive Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); + else + await _queueLogger.WarningAsync($"Rate limit triggered: {id ?? "null"}").ConfigureAwait(false); }; ApiClient.SentRequest += async (method, endpoint, millis) => await _restLogger.VerboseAsync($"{method} {endpoint}: {millis} ms").ConfigureAwait(false); } diff --git a/src/Discord.Net.Rpc/API/DiscordRpcApiClient.cs b/src/Discord.Net.Rpc/API/DiscordRpcApiClient.cs index 67ead6f83..050783f28 100644 --- a/src/Discord.Net.Rpc/API/DiscordRpcApiClient.cs +++ b/src/Discord.Net.Rpc/API/DiscordRpcApiClient.cs @@ -233,7 +233,7 @@ namespace Discord.API var requestTracker = new RpcRequest(options); _requests[guid] = requestTracker; - await _requestQueue.SendAsync(new WebSocketRequest(_webSocketClient, bytes, true, options)).ConfigureAwait(false); + await _requestQueue.SendAsync(new WebSocketRequest(_webSocketClient, null, bytes, true, options)).ConfigureAwait(false); await _sentRpcMessageEvent.InvokeAsync(cmd).ConfigureAwait(false); return await requestTracker.Promise.Task.ConfigureAwait(false); } diff --git a/src/Discord.Net.WebSocket/API/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/API/DiscordSocketApiClient.cs index b1bb61eb2..f0dd5f852 100644 --- a/src/Discord.Net.WebSocket/API/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/API/DiscordSocketApiClient.cs @@ -167,7 +167,7 @@ namespace Discord.API payload = new SocketFrame { Operation = (int)opCode, Payload = payload }; if (payload != null) bytes = Encoding.UTF8.GetBytes(SerializeJson(payload)); - await RequestQueue.SendAsync(new WebSocketRequest(_gatewayClient, bytes, true, options)).ConfigureAwait(false); + await RequestQueue.SendAsync(new WebSocketRequest(_gatewayClient, null, bytes, true, options)).ConfigureAwait(false); await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false); } @@ -175,7 +175,7 @@ namespace Discord.API public async Task GetGatewayAsync(RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - return await SendAsync("GET", "gateway", options: options).ConfigureAwait(false); + return await SendAsync("GET", () => "gateway", new BucketIds(), options: options).ConfigureAwait(false); } public async Task SendIdentifyAsync(int largeThreshold = 100, bool useCompression = true, int shardID = 0, int totalShards = 1, RequestOptions options = null) {