- Add session_start_limit to GetBotGatewayResponse - Add GetBotGatewayAsync to IDiscordClient - Add master/slave semaphores to enable concurrency - Not store semaphore name as static - Clone GatewayLimits when cloning the Configpull/1537/head
@@ -0,0 +1,18 @@ | |||
namespace Discord | |||
{ | |||
public class BotGateway | |||
{ | |||
/// <summary> | |||
/// The WSS URL that can be used for connecting to the gateway. | |||
/// </summary> | |||
public string Url { get; internal set; } | |||
/// <summary> | |||
/// The recommended number of shards to use when connecting. | |||
/// </summary> | |||
public int Shards { get; internal set; } | |||
/// <summary> | |||
/// Information on the current session start limit. | |||
/// </summary> | |||
public SessionStartLimit SessionStartLimit { get; internal set; } | |||
} | |||
} |
@@ -0,0 +1,22 @@ | |||
namespace Discord | |||
{ | |||
public class SessionStartLimit | |||
{ | |||
/// <summary> | |||
/// The total number of session starts the current user is allowed. | |||
/// </summary> | |||
public int Total { get; internal set; } | |||
/// <summary> | |||
/// The remaining number of session starts the current user is allowed. | |||
/// </summary> | |||
public int Remaining { get; internal set; } | |||
/// <summary> | |||
/// The number of milliseconds after which the limit resets. | |||
/// </summary> | |||
public int ResetAfter { get; internal set; } | |||
/// <summary> | |||
/// The maximum concurrent identify requests in a time window. | |||
/// </summary> | |||
public int MaxConcurrency { get; internal set; } | |||
} | |||
} |
@@ -274,5 +274,15 @@ namespace Discord | |||
/// that represents the number of shards that should be used with this account. | |||
/// </returns> | |||
Task<int> GetRecommendedShardCountAsync(RequestOptions options = null); | |||
/// <summary> | |||
/// Gets the gateway information related to the bot. | |||
/// </summary> | |||
/// <param name="options">The options to be used when sending the request.</param> | |||
/// <returns> | |||
/// A task that represents the asynchronous get operation. The task result contains a <see cref="BotGateway"/> | |||
/// that represents the gateway information related to the bot. | |||
/// </returns> | |||
Task<BotGateway> GetBotGatewayAsync(RequestOptions options = null); | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using Newtonsoft.Json; | |||
namespace Discord.API.Rest | |||
{ | |||
internal class SessionStartLimit | |||
{ | |||
[JsonProperty("total")] | |||
public int Total { get; set; } | |||
[JsonProperty("remaining")] | |||
public int Remaining { get; set; } | |||
[JsonProperty("reset_after")] | |||
public int ResetAfter { get; set; } | |||
[JsonProperty("max_concurrency")] | |||
public int MaxConcurrency { get; set; } | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
#pragma warning disable CS1591 | |||
#pragma warning disable CS1591 | |||
using Newtonsoft.Json; | |||
namespace Discord.API.Rest | |||
@@ -9,5 +9,7 @@ namespace Discord.API.Rest | |||
public string Url { get; set; } | |||
[JsonProperty("shards")] | |||
public int Shards { get; set; } | |||
[JsonProperty("session_start_limit")] | |||
public SessionStartLimit SessionStartLimit { get; set; } | |||
} | |||
} |
@@ -152,6 +152,10 @@ namespace Discord.Rest | |||
public Task<int> GetRecommendedShardCountAsync(RequestOptions options = null) | |||
=> ClientHelper.GetRecommendShardCountAsync(this, options); | |||
/// <inheritdoc /> | |||
public Task<BotGateway> GetBotGatewayAsync(RequestOptions options = null) | |||
=> ClientHelper.GetBotGatewayAsync(this, options); | |||
//IDiscordClient | |||
/// <inheritdoc /> | |||
ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected; | |||
@@ -176,5 +176,22 @@ namespace Discord.Rest | |||
var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); | |||
return response.Shards; | |||
} | |||
public static async Task<BotGateway> GetBotGatewayAsync(BaseDiscordClient client, RequestOptions options) | |||
{ | |||
var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false); | |||
return new BotGateway | |||
{ | |||
Url = response.Url, | |||
Shards = response.Shards, | |||
SessionStartLimit = new SessionStartLimit | |||
{ | |||
Total = response.SessionStartLimit.Total, | |||
Remaining = response.SessionStartLimit.Remaining, | |||
ResetAfter = response.SessionStartLimit.ResetAfter, | |||
MaxConcurrency = response.SessionStartLimit.MaxConcurrency | |||
} | |||
}; | |||
} | |||
} | |||
} |
@@ -51,7 +51,7 @@ namespace Discord.API | |||
internal JsonSerializer Serializer => _serializer; | |||
/// <exception cref="ArgumentException">Unknown OAuth token type.</exception> | |||
public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, | |||
public DiscordRestApiClient(RestClientProvider restClientProvider, string userAgent, RequestQueue requestQueue, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, | |||
JsonSerializer serializer = null, RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, bool useSystemClock = true) | |||
{ | |||
_restClientProvider = restClientProvider; | |||
@@ -61,7 +61,7 @@ namespace Discord.API | |||
RateLimitPrecision = rateLimitPrecision; | |||
UseSystemClock = useSystemClock; | |||
RequestQueue = new RequestQueue(); | |||
RequestQueue = requestQueue ?? new RequestQueue(); | |||
_stateLock = new SemaphoreSlim(1, 1); | |||
SetBaseUrl(DiscordConfig.APIUrl); | |||
@@ -7,6 +7,11 @@ namespace Discord.Rest | |||
/// </summary> | |||
public class GatewayLimits | |||
{ | |||
/// <summary> | |||
/// Creates a new <see cref="GatewayLimits"/> with the default values. | |||
/// </summary> | |||
public static GatewayLimits Default => new GatewayLimits(); | |||
/// <summary> | |||
/// Gets or sets the global limits for the gateway rate limiter. | |||
/// </summary> | |||
@@ -15,6 +20,7 @@ namespace Discord.Rest | |||
/// and it is per websocket. | |||
/// </remarks> | |||
public GatewayLimit Global { get; set; } | |||
/// <summary> | |||
/// Gets or sets the limits of Identify requests. | |||
/// </summary> | |||
@@ -23,6 +29,7 @@ namespace Discord.Rest | |||
/// also per account. | |||
/// </remarks> | |||
public GatewayLimit Identify { get; set; } | |||
/// <summary> | |||
/// Gets or sets the limits of Presence Update requests. | |||
/// </summary> | |||
@@ -31,11 +38,35 @@ namespace Discord.Rest | |||
/// and status (online, idle, etc) | |||
/// </remarks> | |||
public GatewayLimit PresenceUpdate { get; set; } | |||
/// <summary> | |||
/// Gets or sets the name of the master <see cref="System.Threading.Semaphore"/> | |||
/// used by identify. | |||
/// </summary> | |||
/// <remarks> | |||
/// It is used to define what slave <see cref="System.Threading.Semaphore"/> | |||
/// is free to run for concurrent identify requests. | |||
/// </remarks> | |||
public string IdentifyMasterSemaphoreName { get; set; } | |||
/// <summary> | |||
/// Gets or sets the name of the <see cref="System.Threading.Semaphore"/> used by identify. | |||
/// Gets or sets the name of the slave <see cref="System.Threading.Semaphore"/> | |||
/// used by identify. | |||
/// </summary> | |||
/// <remarks> | |||
/// If the maximum concurrency is higher than one and you are using the sharded client, | |||
/// it will be dinamilly renamed to fit the necessary needs. | |||
/// </remarks> | |||
public string IdentifySemaphoreName { get; set; } | |||
/// <summary> | |||
/// Gets or sets the maximum identify concurrency. | |||
/// </summary> | |||
/// <remarks> | |||
/// This limit is provided by Discord. | |||
/// </remarks> | |||
public int IdentifyMaxConcurrency { get; set; } | |||
/// <summary> | |||
/// Initializes a new <see cref="GatewayLimits"/> with the default values. | |||
/// </summary> | |||
@@ -44,10 +75,26 @@ namespace Discord.Rest | |||
Global = new GatewayLimit(120, 60); | |||
Identify = new GatewayLimit(1, 5); | |||
PresenceUpdate = new GatewayLimit(5, 60); | |||
IdentifyMasterSemaphoreName = Guid.NewGuid().ToString(); | |||
IdentifySemaphoreName = Guid.NewGuid().ToString(); | |||
IdentifyMaxConcurrency = 1; | |||
} | |||
internal static GatewayLimits GetOrCreate(GatewayLimits limits) | |||
=> limits ?? new GatewayLimits(); | |||
internal GatewayLimits(GatewayLimits limits) | |||
{ | |||
Global = new GatewayLimit(limits.Global.Count, limits.Global.Seconds); | |||
Identify = new GatewayLimit(limits.Identify.Count, limits.Identify.Seconds); | |||
PresenceUpdate = new GatewayLimit(limits.PresenceUpdate.Count, limits.PresenceUpdate.Seconds); | |||
IdentifyMasterSemaphoreName = limits.IdentifyMasterSemaphoreName; | |||
IdentifySemaphoreName = limits.IdentifySemaphoreName; | |||
IdentifyMaxConcurrency = limits.IdentifyMaxConcurrency; | |||
} | |||
internal static GatewayLimits GetOrCreate(GatewayLimits? limits) | |||
=> limits ?? Default; | |||
public GatewayLimits Clone() | |||
=> new GatewayLimits(this); | |||
} | |||
} |
@@ -13,7 +13,6 @@ namespace Discord.Net.Queue | |||
{ | |||
private static ImmutableDictionary<GatewayBucketType, GatewayBucket> DefsByType; | |||
private static ImmutableDictionary<string, GatewayBucket> DefsById; | |||
private static string IdentifySemaphoreName; | |||
static GatewayBucket() | |||
{ | |||
@@ -22,7 +21,6 @@ namespace Discord.Net.Queue | |||
public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type]; | |||
public static GatewayBucket Get(string id) => DefsById[id]; | |||
public static string GetIdentifySemaphoreName() => IdentifySemaphoreName; | |||
public static void SetLimits(GatewayLimits limits) | |||
{ | |||
@@ -50,8 +48,6 @@ namespace Discord.Net.Queue | |||
foreach (var bucket in buckets) | |||
builder2.Add(bucket.Id, bucket); | |||
DefsById = builder2.ToImmutable(); | |||
IdentifySemaphoreName = limits.IdentifySemaphoreName; | |||
} | |||
public GatewayBucketType Type { get; } | |||
@@ -23,15 +23,16 @@ namespace Discord.Net.Queue | |||
private CancellationTokenSource _requestCancelTokenSource; | |||
private CancellationToken _requestCancelToken; //Parent token + Clear token | |||
private DateTimeOffset _waitUntil; | |||
private Semaphore _identifySemaphore; | |||
private readonly Semaphore _masterIdentifySemaphore; | |||
private readonly Semaphore _identifySemaphore; | |||
private readonly int _identifySemaphoreMaxConcurrency; | |||
private Task _cleanupTask; | |||
public RequestQueue() | |||
{ | |||
_tokenLock = new SemaphoreSlim(1, 1); | |||
int semaphoreCount = GatewayBucket.Get(GatewayBucketType.Identify).WindowCount; | |||
_identifySemaphore = new Semaphore(semaphoreCount, semaphoreCount, GatewayBucket.GetIdentifySemaphoreName()); | |||
_clearToken = new CancellationTokenSource(); | |||
_cancelTokenSource = new CancellationTokenSource(); | |||
@@ -43,6 +44,14 @@ namespace Discord.Net.Queue | |||
_cleanupTask = RunCleanup(); | |||
} | |||
public RequestQueue(string masterIdentifySemaphoreName, string slaveIdentifySemaphoreName, int slaveIdentifySemaphoreMaxConcurrency) | |||
: this () | |||
{ | |||
_masterIdentifySemaphore = new Semaphore(1, 1, masterIdentifySemaphoreName); | |||
_identifySemaphore = new Semaphore(0, GatewayBucket.Get(GatewayBucketType.Identify).WindowCount, slaveIdentifySemaphoreName); | |||
_identifySemaphoreMaxConcurrency = slaveIdentifySemaphoreMaxConcurrency; | |||
} | |||
public async Task SetCancelTokenAsync(CancellationToken cancelToken) | |||
{ | |||
await _tokenLock.WaitAsync().ConfigureAwait(false); | |||
@@ -132,8 +141,14 @@ namespace Discord.Net.Queue | |||
//Identify is per-account so we won't trigger global until we can actually go for it | |||
if (requestBucket.Type == GatewayBucketType.Identify) | |||
{ | |||
while (!_identifySemaphore.WaitOne(0)) //To not block the thread | |||
if (_masterIdentifySemaphore == null || _identifySemaphore == null) | |||
throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); | |||
bool master; | |||
while (!(master = _masterIdentifySemaphore.WaitOne(0)) && !_identifySemaphore.WaitOne(0)) //To not block the thread | |||
await Task.Delay(100, request.CancelToken); | |||
if (master && _identifySemaphoreMaxConcurrency > 1) | |||
_identifySemaphore.Release(_identifySemaphoreMaxConcurrency - 1); | |||
#if DEBUG_LIMITS | |||
Debug.WriteLine($"[{id}] Acquired identify ticket"); | |||
#endif | |||
@@ -149,7 +164,12 @@ namespace Discord.Net.Queue | |||
} | |||
internal void ReleaseIdentifySemaphore(int id) | |||
{ | |||
_identifySemaphore.Release(); | |||
if (_masterIdentifySemaphore == null || _identifySemaphore == null) | |||
throw new InvalidOperationException("Not a RequestQueue with WebSocket data."); | |||
while (_identifySemaphore.WaitOne(0)) //exhaust all tickets before releasing master | |||
{ } | |||
_masterIdentifySemaphore.Release(); | |||
#if DEBUG_LIMITS | |||
Debug.WriteLine($"[{id}] Released identify ticket"); | |||
#endif | |||
@@ -80,7 +80,7 @@ namespace Discord.WebSocket | |||
internal BaseSocketClient(DiscordSocketConfig config, DiscordRestApiClient client) | |||
: base(config, client) => BaseConfig = config; | |||
private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) | |||
=> new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, | |||
=> new DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, | |||
rateLimitPrecision: config.RateLimitPrecision, | |||
useSystemClock: config.UseSystemClock); | |||
@@ -81,29 +81,35 @@ namespace Discord.WebSocket | |||
_shardIdsToIndex.Add(_shardIds[i], i); | |||
var newConfig = config.Clone(); | |||
newConfig.ShardId = _shardIds[i]; | |||
if (config.GatewayLimits.IdentifyMaxConcurrency != 1) | |||
newConfig.GatewayLimits.IdentifySemaphoreName += $"_{i / config.GatewayLimits.IdentifyMaxConcurrency}"; | |||
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); | |||
RegisterEvents(_shards[i], i == 0); | |||
} | |||
} | |||
} | |||
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) | |||
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, | |||
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, | |||
rateLimitPrecision: config.RateLimitPrecision); | |||
internal override async Task OnLoginAsync(TokenType tokenType, string token) | |||
{ | |||
if (_automaticShards) | |||
{ | |||
var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false); | |||
_shardIds = Enumerable.Range(0, shardCount).ToArray(); | |||
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); | |||
_shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); | |||
_totalShards = _shardIds.Length; | |||
_shards = new DiscordSocketClient[_shardIds.Length]; | |||
int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency; | |||
_baseConfig.GatewayLimits.IdentifyMaxConcurrency = maxConcurrency; | |||
for (int i = 0; i < _shardIds.Length; i++) | |||
{ | |||
_shardIdsToIndex.Add(_shardIds[i], i); | |||
var newConfig = _baseConfig.Clone(); | |||
newConfig.ShardId = _shardIds[i]; | |||
newConfig.TotalShards = _totalShards; | |||
if (maxConcurrency != 1) | |||
newConfig.GatewayLimits.IdentifySemaphoreName += $"_{i / maxConcurrency}"; | |||
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null); | |||
RegisterEvents(_shards[i], i == 0); | |||
} | |||
@@ -3,6 +3,7 @@ using Discord.API.Gateway; | |||
using Discord.Net.Queue; | |||
using Discord.Net.Rest; | |||
using Discord.Net.WebSockets; | |||
using Discord.Rest; | |||
using Discord.WebSocket; | |||
using Newtonsoft.Json; | |||
using System; | |||
@@ -37,11 +38,11 @@ namespace Discord.API | |||
public ConnectionState ConnectionState { get; private set; } | |||
public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, | |||
public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, GatewayLimits limits, | |||
string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, | |||
RateLimitPrecision rateLimitPrecision = RateLimitPrecision.Second, | |||
bool useSystemClock = true) | |||
: base(restClientProvider, userAgent, defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) | |||
: base(restClientProvider, userAgent, new RequestQueue(limits.IdentifyMasterSemaphoreName, limits.IdentifySemaphoreName, limits.IdentifyMaxConcurrency), defaultRetryMode, serializer, rateLimitPrecision, useSystemClock) | |||
{ | |||
_gatewayUrl = url; | |||
if (url != null) | |||
@@ -182,7 +182,7 @@ namespace Discord.WebSocket | |||
_largeGuilds = new ConcurrentQueue<ulong>(); | |||
} | |||
private static API.DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config) | |||
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayHost, | |||
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent, config.GatewayLimits, config.GatewayHost, | |||
rateLimitPrecision: config.RateLimitPrecision); | |||
/// <inheritdoc /> | |||
internal override void Dispose(bool disposing) | |||
@@ -133,7 +133,7 @@ namespace Discord.WebSocket | |||
/// This property should only be changed for bots that have special limits provided by Discord. | |||
/// </note> | |||
/// </remarks> | |||
public GatewayLimits GatewayLimits { get; set; } = new GatewayLimits(); | |||
public GatewayLimits GatewayLimits { get; set; } = GatewayLimits.Default; | |||
/// <summary> | |||
/// Initializes a default configuration. | |||
@@ -144,6 +144,11 @@ namespace Discord.WebSocket | |||
UdpSocketProvider = DefaultUdpSocketProvider.Instance; | |||
} | |||
internal DiscordSocketConfig Clone() => MemberwiseClone() as DiscordSocketConfig; | |||
internal DiscordSocketConfig Clone() | |||
{ | |||
var clone = MemberwiseClone() as DiscordSocketConfig; | |||
clone.GatewayLimits = GatewayLimits.Clone(); | |||
return clone; | |||
} | |||
} | |||
} |