diff --git a/src/Discord.Net.Commands/Readers/UserTypeReader.cs b/src/Discord.Net.Commands/Readers/UserTypeReader.cs index d7fc6cfdc..c27945e39 100644 --- a/src/Discord.Net.Commands/Readers/UserTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/UserTypeReader.cs @@ -13,7 +13,7 @@ namespace Discord.Commands public override async Task Read(ICommandContext context, string input) { var results = new Dictionary(); - IReadOnlyCollection channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way? + IAsyncEnumerable channelUsers = context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten(); // it's better IReadOnlyCollection guildUsers = ImmutableArray.Create(); ulong id; @@ -46,7 +46,7 @@ namespace Discord.Commands ushort discriminator; if (ushort.TryParse(input.Substring(index + 1), out discriminator)) { - var channelUser = channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && + var channelUser = await channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)); AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f); @@ -58,8 +58,9 @@ namespace Discord.Commands //By Username (0.5-0.6) { - foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) - AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f); + await channelUsers + .Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)) + .ForEachAsync(channelUser => AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f)); foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); @@ -67,8 +68,9 @@ namespace Discord.Commands //By Nickname (0.5-0.6) { - foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase))) - AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f); + await channelUsers + .Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase)) + .ForEachAsync(channelUser => AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f)); foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); diff --git a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs index f52edd719..345154f1d 100644 --- a/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs +++ b/src/Discord.Net.Core/Extensions/AsyncEnumerableExtensions.cs @@ -1,14 +1,64 @@ using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace Discord { public static class AsyncEnumerableExtensions { - public static async Task> Flatten(this IAsyncEnumerable> source) + /// + /// Flattens the specified pages into one asynchronously + /// + /// + /// + /// + public static async Task> FlattenAsync(this IAsyncEnumerable> source) { - return (await source.ToArray().ConfigureAwait(false)).SelectMany(x => x); + return await source.Flatten().ToArray().ConfigureAwait(false); + } + + public static IAsyncEnumerable Flatten(this IAsyncEnumerable> source) + { + return new PagedCollectionEnumerator(source); + } + + internal class PagedCollectionEnumerator : IAsyncEnumerator, IAsyncEnumerable + { + readonly IAsyncEnumerator> _source; + IEnumerator _enumerator; + + public IAsyncEnumerator GetEnumerator() => this; + + internal PagedCollectionEnumerator(IAsyncEnumerable> source) + { + _source = source.GetEnumerator(); + } + + public T Current => _enumerator.Current; + + public void Dispose() + { + _enumerator?.Dispose(); + _source.Dispose(); + } + + public async Task MoveNext(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if(!_enumerator?.MoveNext() ?? true) + { + if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false)) + return false; + + _enumerator?.Dispose(); + _enumerator = _source.Current.GetEnumerator(); + return _enumerator.MoveNext(); + } + + return true; + } } } } diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs index 8bc800a7d..a4a048a05 100644 --- a/src/Discord.Net.Rest/ClientHelper.cs +++ b/src/Discord.Net.Rest/ClientHelper.cs @@ -79,7 +79,7 @@ namespace Discord.Rest ulong? fromGuildId, int? limit, RequestOptions options) { return new PagedAsyncEnumerable( - DiscordConfig.MaxUsersPerBatch, + DiscordConfig.MaxGuildsPerBatch, async (info, ct) => { var args = new GetGuildSummariesParams @@ -106,7 +106,7 @@ namespace Discord.Rest } public static async Task> GetGuildsAsync(BaseDiscordClient client, RequestOptions options) { - var summaryModels = await GetGuildSummariesAsync(client, null, null, options).Flatten(); + var summaryModels = await GetGuildSummariesAsync(client, null, null, options).FlattenAsync().ConfigureAwait(false); var guilds = ImmutableArray.CreateBuilder(); foreach (var summaryModel in summaryModels) { diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index 8b5598ffe..c8ec9fb73 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -356,7 +356,7 @@ namespace Discord.Rest async Task> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options) { if (mode == CacheMode.AllowDownload) - return (await GetUsersAsync(options).Flatten().ConfigureAwait(false)).ToImmutableArray(); + return (await GetUsersAsync(options).FlattenAsync().ConfigureAwait(false)).ToImmutableArray(); else return ImmutableArray.Create(); }