@@ -109,12 +109,81 @@ namespace Discord.Rest | |||||
public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | ||||
ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | ||||
{ | { | ||||
if (dir == Direction.Around) | |||||
throw new NotImplementedException(); //TODO: Impl | |||||
var guildId = (channel as IGuildChannel)?.GuildId; | var guildId = (channel as IGuildChannel)?.GuildId; | ||||
var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | ||||
if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) | |||||
{ | |||||
int around = limit / 2; | |||||
return new PagedAsyncEnumerable<RestMessage>( | |||||
DiscordConfig.MaxMessagesPerBatch, | |||||
async (info, ct) => | |||||
{ | |||||
var args = new GetChannelMessagesParams | |||||
{ | |||||
RelativeDirection = Direction.Before, | |||||
Limit = info.PageSize | |||||
}; | |||||
if (info.Position != null) | |||||
args.RelativeMessageId = info.Position.Value; | |||||
var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||||
var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||||
foreach (var model in models) | |||||
{ | |||||
var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||||
builder.Add(RestMessage.Create(client, channel, author, model)); | |||||
} | |||||
return builder.ToImmutable(); | |||||
}, | |||||
nextPage: (info, lastPage) => | |||||
{ | |||||
if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||||
return false; | |||||
if (dir == Direction.Before) | |||||
info.Position = lastPage.Min(x => x.Id); | |||||
else | |||||
info.Position = lastPage.Max(x => x.Id); | |||||
return true; | |||||
}, | |||||
start: fromMessageId + 1, //Needs to include the message itself | |||||
count: around + 1 | |||||
).Concat(new PagedAsyncEnumerable<RestMessage>( | |||||
DiscordConfig.MaxMessagesPerBatch, | |||||
async (info, ct) => | |||||
{ | |||||
var args = new GetChannelMessagesParams | |||||
{ | |||||
RelativeDirection = Direction.After, | |||||
Limit = info.PageSize | |||||
}; | |||||
if (info.Position != null) | |||||
args.RelativeMessageId = info.Position.Value; | |||||
var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||||
var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||||
foreach (var model in models) | |||||
{ | |||||
var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||||
builder.Add(RestMessage.Create(client, channel, author, model)); | |||||
} | |||||
return builder.ToImmutable(); | |||||
}, | |||||
nextPage: (info, lastPage) => | |||||
{ | |||||
if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||||
return false; | |||||
if (dir == Direction.Before) | |||||
info.Position = lastPage.Min(x => x.Id); | |||||
else | |||||
info.Position = lastPage.Max(x => x.Id); | |||||
return true; | |||||
}, | |||||
start: fromMessageId, | |||||
count: around | |||||
)); | |||||
} | |||||
return new PagedAsyncEnumerable<RestMessage>( | return new PagedAsyncEnumerable<RestMessage>( | ||||
DiscordConfig.MaxMessagesPerBatch, | DiscordConfig.MaxMessagesPerBatch, | ||||
async (info, ct) => | async (info, ct) => | ||||
@@ -11,23 +11,11 @@ namespace Discord.WebSocket | |||||
public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | ||||
ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | ||||
{ | { | ||||
if (dir == Direction.Around) | |||||
throw new NotImplementedException(); //TODO: Impl | |||||
IReadOnlyCollection<SocketMessage> cachedMessages = null; | |||||
IAsyncEnumerable<IReadOnlyCollection<IMessage>> result = null; | |||||
if (dir == Direction.After && fromMessageId == null) | if (dir == Direction.After && fromMessageId == null) | ||||
return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | ||||
if (dir == Direction.Before || mode == CacheMode.CacheOnly) | |||||
{ | |||||
if (messages != null) //Cache enabled | |||||
cachedMessages = messages.GetMany(fromMessageId, dir, limit); | |||||
else | |||||
cachedMessages = ImmutableArray.Create<SocketMessage>(); | |||||
result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||||
} | |||||
var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); | |||||
var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||||
if (dir == Direction.Before) | if (dir == Direction.Before) | ||||
{ | { | ||||
@@ -38,18 +26,35 @@ namespace Discord.WebSocket | |||||
//Download remaining messages | //Download remaining messages | ||||
ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | ||||
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | ||||
return result.Concat(downloadedMessages); | |||||
if (cachedMessages.Count != 0) | |||||
return result.Concat(downloadedMessages); | |||||
else | |||||
return downloadedMessages; | |||||
} | } | ||||
else | |||||
else if (dir == Direction.After) | |||||
{ | |||||
limit -= cachedMessages.Count; | |||||
if (mode == CacheMode.CacheOnly || limit <= 0) | |||||
return result; | |||||
//Download remaining messages | |||||
ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; | |||||
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); | |||||
if (cachedMessages.Count != 0) | |||||
return result.Concat(downloadedMessages); | |||||
else | |||||
return downloadedMessages; | |||||
} | |||||
else //Direction.Around | |||||
{ | { | ||||
if (mode == CacheMode.CacheOnly) | |||||
if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) | |||||
return result; | return result; | ||||
//Dont use cache in this case | |||||
//Cache isn't useful here since Discord will send them anyways | |||||
return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | ||||
} | } | ||||
} | } | ||||
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, | |||||
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||||
ulong? fromMessageId, Direction dir, int limit) | ulong? fromMessageId, Direction dir, int limit) | ||||
{ | { | ||||
if (messages != null) //Cache enabled | if (messages != null) //Cache enabled | ||||
@@ -56,11 +56,41 @@ namespace Discord.WebSocket | |||||
cachedMessageIds = _orderedMessages; | cachedMessageIds = _orderedMessages; | ||||
else if (dir == Direction.Before) | else if (dir == Direction.Before) | ||||
cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | ||||
else | |||||
else if (dir == Direction.After) | |||||
cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | ||||
else //Direction.Around | |||||
{ | |||||
if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) | |||||
return ImmutableArray<SocketMessage>.Empty; | |||||
int around = limit / 2; | |||||
var before = _orderedMessages | |||||
.Where(x => x < fromMessageId.Value) | |||||
.Select(x => | |||||
{ | |||||
if (_messages.TryGetValue(x, out SocketMessage msg)) | |||||
return msg; | |||||
return null; | |||||
}) | |||||
.Where(x => x != null) | |||||
.Take(around); | |||||
var after = _orderedMessages | |||||
.Where(x => x > fromMessageId.Value) | |||||
.Select(x => | |||||
{ | |||||
if (_messages.TryGetValue(x, out SocketMessage msg)) | |||||
return msg; | |||||
return null; | |||||
}) | |||||
.Where(x => x != null) | |||||
.Take(around); | |||||
return before.Concat(new SocketMessage[] { msg }).Concat(after).ToImmutableArray(); | |||||
} | |||||
if (dir == Direction.Before) | if (dir == Direction.Before) | ||||
cachedMessageIds = cachedMessageIds.Reverse(); | cachedMessageIds = cachedMessageIds.Reverse(); | ||||
if (dir == Direction.Around) | |||||
limit /= 2; | |||||
return cachedMessageIds | return cachedMessageIds | ||||
.Select(x => | .Select(x => | ||||