@@ -109,12 +109,81 @@ namespace Discord.Rest | |||
public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client, | |||
ulong? fromMessageId, Direction dir, int limit, RequestOptions options) | |||
{ | |||
if (dir == Direction.Around) | |||
throw new NotImplementedException(); //TODO: Impl | |||
var guildId = (channel as IGuildChannel)?.GuildId; | |||
var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null; | |||
if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch) | |||
{ | |||
int around = limit / 2; | |||
return new PagedAsyncEnumerable<RestMessage>( | |||
DiscordConfig.MaxMessagesPerBatch, | |||
async (info, ct) => | |||
{ | |||
var args = new GetChannelMessagesParams | |||
{ | |||
RelativeDirection = Direction.Before, | |||
Limit = info.PageSize | |||
}; | |||
if (info.Position != null) | |||
args.RelativeMessageId = info.Position.Value; | |||
var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||
var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||
foreach (var model in models) | |||
{ | |||
var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||
builder.Add(RestMessage.Create(client, channel, author, model)); | |||
} | |||
return builder.ToImmutable(); | |||
}, | |||
nextPage: (info, lastPage) => | |||
{ | |||
if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||
return false; | |||
if (dir == Direction.Before) | |||
info.Position = lastPage.Min(x => x.Id); | |||
else | |||
info.Position = lastPage.Max(x => x.Id); | |||
return true; | |||
}, | |||
start: fromMessageId + 1, //Needs to include the message itself | |||
count: around + 1 | |||
).Concat(new PagedAsyncEnumerable<RestMessage>( | |||
DiscordConfig.MaxMessagesPerBatch, | |||
async (info, ct) => | |||
{ | |||
var args = new GetChannelMessagesParams | |||
{ | |||
RelativeDirection = Direction.After, | |||
Limit = info.PageSize | |||
}; | |||
if (info.Position != null) | |||
args.RelativeMessageId = info.Position.Value; | |||
var models = await client.ApiClient.GetChannelMessagesAsync(channel.Id, args, options).ConfigureAwait(false); | |||
var builder = ImmutableArray.CreateBuilder<RestMessage>(); | |||
foreach (var model in models) | |||
{ | |||
var author = GetAuthor(client, guild, model.Author.Value, model.WebhookId.ToNullable()); | |||
builder.Add(RestMessage.Create(client, channel, author, model)); | |||
} | |||
return builder.ToImmutable(); | |||
}, | |||
nextPage: (info, lastPage) => | |||
{ | |||
if (lastPage.Count != DiscordConfig.MaxMessagesPerBatch) | |||
return false; | |||
if (dir == Direction.Before) | |||
info.Position = lastPage.Min(x => x.Id); | |||
else | |||
info.Position = lastPage.Max(x => x.Id); | |||
return true; | |||
}, | |||
start: fromMessageId, | |||
count: around | |||
)); | |||
} | |||
return new PagedAsyncEnumerable<RestMessage>( | |||
DiscordConfig.MaxMessagesPerBatch, | |||
async (info, ct) => | |||
@@ -11,23 +11,11 @@ namespace Discord.WebSocket | |||
public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options) | |||
{ | |||
if (dir == Direction.Around) | |||
throw new NotImplementedException(); //TODO: Impl | |||
IReadOnlyCollection<SocketMessage> cachedMessages = null; | |||
IAsyncEnumerable<IReadOnlyCollection<IMessage>> result = null; | |||
if (dir == Direction.After && fromMessageId == null) | |||
return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>(); | |||
if (dir == Direction.Before || mode == CacheMode.CacheOnly) | |||
{ | |||
if (messages != null) //Cache enabled | |||
cachedMessages = messages.GetMany(fromMessageId, dir, limit); | |||
else | |||
cachedMessages = ImmutableArray.Create<SocketMessage>(); | |||
result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||
} | |||
var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit); | |||
var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>(); | |||
if (dir == Direction.Before) | |||
{ | |||
@@ -38,18 +26,35 @@ namespace Discord.WebSocket | |||
//Download remaining messages | |||
ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId; | |||
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options); | |||
return result.Concat(downloadedMessages); | |||
if (cachedMessages.Count != 0) | |||
return result.Concat(downloadedMessages); | |||
else | |||
return downloadedMessages; | |||
} | |||
else | |||
else if (dir == Direction.After) | |||
{ | |||
limit -= cachedMessages.Count; | |||
if (mode == CacheMode.CacheOnly || limit <= 0) | |||
return result; | |||
//Download remaining messages | |||
ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value; | |||
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options); | |||
if (cachedMessages.Count != 0) | |||
return result.Concat(downloadedMessages); | |||
else | |||
return downloadedMessages; | |||
} | |||
else //Direction.Around | |||
{ | |||
if (mode == CacheMode.CacheOnly) | |||
if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count) | |||
return result; | |||
//Dont use cache in this case | |||
//Cache isn't useful here since Discord will send them anyways | |||
return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options); | |||
} | |||
} | |||
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages, | |||
ulong? fromMessageId, Direction dir, int limit) | |||
{ | |||
if (messages != null) //Cache enabled | |||
@@ -56,11 +56,41 @@ namespace Discord.WebSocket | |||
cachedMessageIds = _orderedMessages; | |||
else if (dir == Direction.Before) | |||
cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value); | |||
else | |||
else if (dir == Direction.After) | |||
cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value); | |||
else //Direction.Around | |||
{ | |||
if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg)) | |||
return ImmutableArray<SocketMessage>.Empty; | |||
int around = limit / 2; | |||
var before = _orderedMessages | |||
.Where(x => x < fromMessageId.Value) | |||
.Select(x => | |||
{ | |||
if (_messages.TryGetValue(x, out SocketMessage msg)) | |||
return msg; | |||
return null; | |||
}) | |||
.Where(x => x != null) | |||
.Take(around); | |||
var after = _orderedMessages | |||
.Where(x => x > fromMessageId.Value) | |||
.Select(x => | |||
{ | |||
if (_messages.TryGetValue(x, out SocketMessage msg)) | |||
return msg; | |||
return null; | |||
}) | |||
.Where(x => x != null) | |||
.Take(around); | |||
return before.Concat(new SocketMessage[] { msg }).Concat(after).ToImmutableArray(); | |||
} | |||
if (dir == Direction.Before) | |||
cachedMessageIds = cachedMessageIds.Reverse(); | |||
if (dir == Direction.Around) | |||
limit /= 2; | |||
return cachedMessageIds | |||
.Select(x => | |||