@@ -1,7 +1,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Concurrent; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Collections.Immutable; | using System.Collections.Immutable; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | |||||
using System.Reflection; | using System.Reflection; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
@@ -10,6 +12,9 @@ namespace Discord.Commands | |||||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | [DebuggerDisplay(@"{DebuggerDisplay,nq}")] | ||||
public class Command | public class Command | ||||
{ | { | ||||
private static readonly MethodInfo _convertParamsMethod = typeof(Command).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); | |||||
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>(); | |||||
private readonly object _instance; | private readonly object _instance; | ||||
private readonly Func<IMessage, IReadOnlyList<object>, Task> _action; | private readonly Func<IMessage, IReadOnlyList<object>, Task> _action; | ||||
@@ -19,6 +24,7 @@ namespace Discord.Commands | |||||
public string Description { get; } | public string Description { get; } | ||||
public string Summary { get; } | public string Summary { get; } | ||||
public string Text { get; } | public string Text { get; } | ||||
public bool HasVarArgs { get; } | |||||
public IReadOnlyList<CommandParameter> Parameters { get; } | public IReadOnlyList<CommandParameter> Parameters { get; } | ||||
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | ||||
@@ -42,8 +48,9 @@ namespace Discord.Commands | |||||
var summary = source.GetCustomAttribute<SummaryAttribute>(); | var summary = source.GetCustomAttribute<SummaryAttribute>(); | ||||
if (summary != null) | if (summary != null) | ||||
Summary = summary.Text; | Summary = summary.Text; | ||||
Parameters = BuildParameters(source); | Parameters = BuildParameters(source); | ||||
HasVarArgs = Parameters.Count > 0 ? Parameters[Parameters.Count - 1].IsMultiple : false; | |||||
Preconditions = BuildPreconditions(source); | Preconditions = BuildPreconditions(source); | ||||
_action = BuildAction(source); | _action = BuildAction(source); | ||||
} | } | ||||
@@ -76,14 +83,38 @@ namespace Discord.Commands | |||||
return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false); | return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false); | ||||
} | } | ||||
public async Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult) | |||||
public Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult) | |||||
{ | { | ||||
if (!parseResult.IsSuccess) | if (!parseResult.IsSuccess) | ||||
return ExecuteResult.FromError(parseResult); | |||||
return Task.FromResult(ExecuteResult.FromError(parseResult)); | |||||
var argList = new object[parseResult.ArgValues.Count]; | |||||
for (int i = 0; i < parseResult.ArgValues.Count; i++) | |||||
{ | |||||
if (!parseResult.ArgValues[i].IsSuccess) | |||||
return Task.FromResult(ExecuteResult.FromError(parseResult.ArgValues[i])); | |||||
argList[i] = parseResult.ArgValues[i].Values.First().Value; | |||||
} | |||||
object[] paramList = null; | |||||
if (parseResult.ParamValues != null) | |||||
{ | |||||
paramList = new object[parseResult.ParamValues.Count]; | |||||
for (int i = 0; i < parseResult.ParamValues.Count; i++) | |||||
{ | |||||
if (!parseResult.ParamValues[i].IsSuccess) | |||||
return Task.FromResult(ExecuteResult.FromError(parseResult.ParamValues[i])); | |||||
paramList[i] = parseResult.ParamValues[i].Values.First().Value; | |||||
} | |||||
} | |||||
return Execute(msg, argList, paramList); | |||||
} | |||||
public async Task<ExecuteResult> Execute(IMessage msg, IEnumerable<object> argList, IEnumerable<object> paramList) | |||||
{ | |||||
try | try | ||||
{ | { | ||||
await _action.Invoke(msg, parseResult.Values);//Note: This code may need context | |||||
await _action.Invoke(msg, GenerateArgs(argList, paramList)).ConfigureAwait(false);//Note: This code may need context | |||||
return ExecuteResult.FromSuccess(); | return ExecuteResult.FromSuccess(); | ||||
} | } | ||||
catch (Exception ex) | catch (Exception ex) | ||||
@@ -108,7 +139,7 @@ namespace Discord.Commands | |||||
{ | { | ||||
var parameter = parameters[i]; | var parameter = parameters[i]; | ||||
var type = parameter.ParameterType; | var type = parameter.ParameterType; | ||||
//Detect 'params' | //Detect 'params' | ||||
bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null; | bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null; | ||||
if (isMultiple) | if (isMultiple) | ||||
@@ -156,6 +187,39 @@ namespace Discord.Commands | |||||
}; | }; | ||||
} | } | ||||
private object[] GenerateArgs(IEnumerable<object> argList, IEnumerable<object> paramsList) | |||||
{ | |||||
int argCount = Parameters.Count; | |||||
var array = new object[Parameters.Count]; | |||||
if (HasVarArgs) | |||||
argCount--; | |||||
int i = 0; | |||||
foreach (var arg in argList) | |||||
{ | |||||
if (i == argCount) | |||||
throw new InvalidOperationException("Command was invoked with too many parameters"); | |||||
array[i++] = arg; | |||||
} | |||||
if (i < argCount) | |||||
throw new InvalidOperationException("Command was invoked with too few parameters"); | |||||
if (HasVarArgs) | |||||
{ | |||||
var func = _arrayConverters.GetOrAdd(Parameters[Parameters.Count - 1].ElementType, t => | |||||
{ | |||||
var method = _convertParamsMethod.MakeGenericMethod(t); | |||||
return (Func<IEnumerable<object>, object>)method.CreateDelegate(typeof(Func<IEnumerable<object>, object>)); | |||||
}); | |||||
array[i] = func(paramsList); | |||||
} | |||||
return array; | |||||
} | |||||
private static T[] ConvertParamsList<T>(IEnumerable<object> paramsList) | |||||
=> paramsList.Cast<T>().ToArray(); | |||||
public override string ToString() => Name; | public override string ToString() => Name; | ||||
private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})"; | private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})"; | ||||
} | } | ||||
@@ -3,14 +3,14 @@ | |||||
public enum CommandError | public enum CommandError | ||||
{ | { | ||||
//Search | //Search | ||||
UnknownCommand, | |||||
UnknownCommand = 1, | |||||
//Parse | //Parse | ||||
ParseFailed, | ParseFailed, | ||||
BadArgCount, | BadArgCount, | ||||
//Parse (Type Reader) | //Parse (Type Reader) | ||||
CastFailed, | |||||
//CastFailed, | |||||
ObjectNotFound, | ObjectNotFound, | ||||
MultipleMatches, | MultipleMatches, | ||||
@@ -17,7 +17,7 @@ namespace Discord.Commands | |||||
public bool IsRemainder { get; } | public bool IsRemainder { get; } | ||||
public bool IsMultiple { get; } | public bool IsMultiple { get; } | ||||
public Type ElementType { get; } | public Type ElementType { get; } | ||||
internal object DefaultValue { get; } | |||||
public object DefaultValue { get; } | |||||
public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue) | public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue) | ||||
{ | { | ||||
@@ -1,8 +1,5 @@ | |||||
using System; | |||||
using System.Collections.Concurrent; | |||||
using System.Collections.Generic; | |||||
| |||||
using System.Collections.Immutable; | using System.Collections.Immutable; | ||||
using System.Reflection; | |||||
using System.Text; | using System.Text; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
@@ -16,9 +13,6 @@ namespace Discord.Commands | |||||
Parameter, | Parameter, | ||||
QuotedParameter | QuotedParameter | ||||
} | } | ||||
private static readonly MethodInfo _convertArrayMethod = typeof(CommandParser).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); | |||||
private static readonly ConcurrentDictionary<Type, Func<List<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<List<object>, object>>(); | |||||
public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos) | public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos) | ||||
{ | { | ||||
@@ -27,9 +21,10 @@ namespace Discord.Commands | |||||
int endPos = input.Length; | int endPos = input.Length; | ||||
var curPart = ParserPart.None; | var curPart = ParserPart.None; | ||||
int lastArgEndPos = int.MinValue; | int lastArgEndPos = int.MinValue; | ||||
var argList = ImmutableArray.CreateBuilder<object>(); | |||||
List<object> paramsList = null; // TODO: could we use a better type? | |||||
var argList = ImmutableArray.CreateBuilder<TypeReaderResult>(); | |||||
ImmutableArray<TypeReaderResult>.Builder paramList = null; | |||||
bool isEscaping = false; | bool isEscaping = false; | ||||
bool hasMultipleMatches = false; | |||||
char c; | char c; | ||||
for (int curPos = startPos; curPos <= endPos; curPos++) | for (int curPos = startPos; curPos <= endPos; curPos++) | ||||
@@ -117,30 +112,28 @@ namespace Discord.Commands | |||||
var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false); | var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false); | ||||
if (!typeReaderResult.IsSuccess) | if (!typeReaderResult.IsSuccess) | ||||
return ParseResult.FromError(typeReaderResult); | |||||
{ | |||||
if (typeReaderResult.Error == CommandError.MultipleMatches) | |||||
hasMultipleMatches = true; | |||||
else | |||||
return ParseResult.FromError(typeReaderResult); | |||||
} | |||||
if (curParam.IsMultiple) | if (curParam.IsMultiple) | ||||
{ | { | ||||
if (paramsList == null) | |||||
paramsList = new List<object>(); | |||||
paramsList.Add(typeReaderResult.Value); | |||||
if (paramList == null) | |||||
paramList = ImmutableArray.CreateBuilder<TypeReaderResult>(); | |||||
paramList.Add(typeReaderResult); | |||||
if (curPos == endPos) | if (curPos == endPos) | ||||
{ | { | ||||
var func = _arrayConverters.GetOrAdd(curParam.ElementType, t => | |||||
{ | |||||
var method = _convertArrayMethod.MakeGenericMethod(t); | |||||
return (Func<List<object>, object>)method.CreateDelegate(typeof(Func<List<object>, object>)); | |||||
}); | |||||
argList.Add(func.Invoke(paramsList)); | |||||
curParam = null; | curParam = null; | ||||
curPart = ParserPart.None; | curPart = ParserPart.None; | ||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
argList.Add(typeReaderResult.Value); | |||||
argList.Add(typeReaderResult); | |||||
curParam = null; | curParam = null; | ||||
curPart = ParserPart.None; | curPart = ParserPart.None; | ||||
@@ -154,34 +147,24 @@ namespace Discord.Commands | |||||
var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false); | var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false); | ||||
if (!typeReaderResult.IsSuccess) | if (!typeReaderResult.IsSuccess) | ||||
return ParseResult.FromError(typeReaderResult); | return ParseResult.FromError(typeReaderResult); | ||||
argList.Add(typeReaderResult.Value); | |||||
argList.Add(typeReaderResult); | |||||
} | } | ||||
if (isEscaping) | if (isEscaping) | ||||
return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape."); | return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape."); | ||||
if (curPart == ParserPart.QuotedParameter) | if (curPart == ParserPart.QuotedParameter) | ||||
return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete"); | return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete"); | ||||
if (argList.Count < command.Parameters.Count) | |||||
//Add missing optionals | |||||
for (int i = paramList.Count; i < command.Parameters.Count; i++) | |||||
{ | { | ||||
for (int i = argList.Count; i < command.Parameters.Count; i++) | |||||
{ | |||||
var param = command.Parameters[i]; | |||||
if (!param.IsOptional) | |||||
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters."); | |||||
argList.Add(param.DefaultValue); | |||||
} | |||||
var param = command.Parameters[i]; | |||||
if (!param.IsOptional) | |||||
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters."); | |||||
argList.Add(TypeReaderResult.FromSuccess(param.DefaultValue)); | |||||
} | } | ||||
return ParseResult.FromSuccess(argList.ToImmutable()); | |||||
} | |||||
private static T[] ConvertParamsList<T>(List<object> paramsList) | |||||
{ | |||||
var array = new T[paramsList.Count]; | |||||
for (int i = 0; i < array.Length; i++) | |||||
array[i] = (T)paramsList[i]; | |||||
return array; | |||||
return ParseResult.FromSuccess(argList.ToImmutable(), paramList?.ToImmutable()); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -40,16 +40,8 @@ namespace Discord.Commands | |||||
[typeof(decimal)] = new SimpleTypeReader<decimal>(), | [typeof(decimal)] = new SimpleTypeReader<decimal>(), | ||||
[typeof(DateTime)] = new SimpleTypeReader<DateTime>(), | [typeof(DateTime)] = new SimpleTypeReader<DateTime>(), | ||||
[typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(), | [typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(), | ||||
//TODO: Do we want to support any other interfaces? | |||||
//[typeof(IMentionable)] = new GeneralTypeReader(), | |||||
//[typeof(ISnowflakeEntity)] = new GeneralTypeReader(), | |||||
//[typeof(IEntity<ulong>)] = new GeneralTypeReader(), | |||||
[typeof(IMessage)] = new MessageTypeReader(), | [typeof(IMessage)] = new MessageTypeReader(), | ||||
//[typeof(IAttachment)] = new xxx(), | |||||
//[typeof(IEmbed)] = new xxx(), | |||||
[typeof(IChannel)] = new ChannelTypeReader<IChannel>(), | [typeof(IChannel)] = new ChannelTypeReader<IChannel>(), | ||||
[typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(), | [typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(), | ||||
@@ -61,10 +53,8 @@ namespace Discord.Commands | |||||
[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(), | [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(), | ||||
//[typeof(IGuild)] = new GuildTypeReader<IGuild>(), | //[typeof(IGuild)] = new GuildTypeReader<IGuild>(), | ||||
//[typeof(IUserGuild)] = new GuildTypeReader<IUserGuild>(), | |||||
//[typeof(IGuildIntegration)] = new xxx(), | |||||
[typeof(IRole)] = new RoleTypeReader(), | |||||
[typeof(IRole)] = new RoleTypeReader<IRole>(), | |||||
//[typeof(IInvite)] = new InviteTypeReader<IInvite>(), | //[typeof(IInvite)] = new InviteTypeReader<IInvite>(), | ||||
//[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(), | //[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(), | ||||
@@ -72,10 +62,6 @@ namespace Discord.Commands | |||||
[typeof(IUser)] = new UserTypeReader<IUser>(), | [typeof(IUser)] = new UserTypeReader<IUser>(), | ||||
[typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(), | [typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(), | ||||
[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(), | [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(), | ||||
//[typeof(ISelfUser)] = new UserTypeReader<ISelfUser>(), | |||||
//[typeof(IPresence)] = new UserTypeReader<IPresence>(), | |||||
//[typeof(IVoiceState)] = new UserTypeReader<IVoiceState>(), | |||||
//[typeof(IConnection)] = new xxx(), | |||||
}; | }; | ||||
} | } | ||||
@@ -201,8 +187,9 @@ namespace Discord.Commands | |||||
return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); | return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); | ||||
} | } | ||||
public Task<IResult> Execute(IMessage message, int argPos) => Execute(message, message.Content.Substring(argPos)); | |||||
public async Task<IResult> Execute(IMessage message, string input) | |||||
public Task<IResult> Execute(IMessage message, int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) | |||||
=> Execute(message, message.Content.Substring(argPos), multiMatchHandling); | |||||
public async Task<IResult> Execute(IMessage message, string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception) | |||||
{ | { | ||||
var searchResult = Search(message, input); | var searchResult = Search(message, input); | ||||
if (!searchResult.IsSuccess) | if (!searchResult.IsSuccess) | ||||
@@ -223,14 +210,29 @@ namespace Discord.Commands | |||||
var parseResult = await commands[i].Parse(message, searchResult, preconditionResult); | var parseResult = await commands[i].Parse(message, searchResult, preconditionResult); | ||||
if (!parseResult.IsSuccess) | if (!parseResult.IsSuccess) | ||||
{ | { | ||||
if (commands.Count == 1) | |||||
return parseResult; | |||||
else | |||||
continue; | |||||
if (parseResult.Error == CommandError.MultipleMatches) | |||||
{ | |||||
TypeReaderValue[] argList, paramList; | |||||
switch (multiMatchHandling) | |||||
{ | |||||
case MultiMatchHandling.Best: | |||||
argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray(); | |||||
paramList = parseResult.ParamValues?.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray(); | |||||
parseResult = ParseResult.FromSuccess(argList, paramList); | |||||
break; | |||||
} | |||||
} | |||||
if (!parseResult.IsSuccess) | |||||
{ | |||||
if (commands.Count == 1) | |||||
return parseResult; | |||||
else | |||||
continue; | |||||
} | |||||
} | } | ||||
var executeResult = await commands[i].Execute(message, parseResult); | |||||
return executeResult; | |||||
return await commands[i].Execute(message, parseResult); | |||||
} | } | ||||
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); | return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); | ||||
@@ -0,0 +1,8 @@ | |||||
namespace Discord.Commands | |||||
{ | |||||
public enum MultiMatchHandling | |||||
{ | |||||
Exception, | |||||
Best | |||||
} | |||||
} |
@@ -1,4 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Globalization; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
@@ -9,40 +11,37 @@ namespace Discord.Commands | |||||
{ | { | ||||
public override async Task<TypeReaderResult> Read(IMessage context, string input) | public override async Task<TypeReaderResult> Read(IMessage context, string input) | ||||
{ | { | ||||
IGuildChannel guildChannel = context.Channel as IGuildChannel; | |||||
IChannel result = null; | |||||
var guild = (context.Channel as IGuildChannel)?.Guild; | |||||
if (guildChannel != null) | |||||
if (guild != null) | |||||
{ | { | ||||
//By Id | |||||
var results = new Dictionary<ulong, TypeReaderValue>(); | |||||
var channels = await guild.GetChannelsAsync().ConfigureAwait(false); | |||||
ulong id; | ulong id; | ||||
if (MentionUtils.TryParseChannel(input, out id) || ulong.TryParse(input, out id)) | |||||
{ | |||||
var channel = await guildChannel.Guild.GetChannelAsync(id).ConfigureAwait(false); | |||||
if (channel != null) | |||||
result = channel; | |||||
} | |||||
//By Name | |||||
if (result == null) | |||||
{ | |||||
var channels = await guildChannel.Guild.GetChannelsAsync().ConfigureAwait(false); | |||||
var filteredChannels = channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); | |||||
if (filteredChannels.Length > 1) | |||||
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple channels found."); | |||||
else if (filteredChannels.Length == 1) | |||||
result = filteredChannels[0]; | |||||
} | |||||
//By Mention (1.0) | |||||
if (MentionUtils.TryParseChannel(input, out id)) | |||||
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); | |||||
//By Id (0.9) | |||||
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) | |||||
AddResult(results, await guild.GetChannelAsync(id).ConfigureAwait(false) as T, 0.90f); | |||||
//By Name (0.7-0.8) | |||||
foreach (var channel in channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, channel as T, channel.Name == input ? 0.80f : 0.70f); | |||||
if (results.Count > 0) | |||||
return TypeReaderResult.FromSuccess(results.Values); | |||||
} | } | ||||
if (result == null) | |||||
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); | |||||
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); | |||||
} | |||||
T castResult = result as T; | |||||
if (castResult == null) | |||||
return TypeReaderResult.FromError(CommandError.CastFailed, $"Channel is not a {typeof(T).Name}."); | |||||
else | |||||
return TypeReaderResult.FromSuccess(castResult); | |||||
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T channel, float score) | |||||
{ | |||||
if (channel != null && !results.ContainsKey(channel.Id)) | |||||
results.Add(channel.Id, new TypeReaderValue(channel, score)); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -52,14 +52,14 @@ namespace Discord.Commands | |||||
if (_enumsByValue.TryGetValue(baseValue, out enumValue)) | if (_enumsByValue.TryGetValue(baseValue, out enumValue)) | ||||
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); | return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); | ||||
else | else | ||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}")); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
if (_enumsByName.TryGetValue(input.ToLower(), out enumValue)) | if (_enumsByName.TryGetValue(input.ToLower(), out enumValue)) | ||||
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); | return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); | ||||
else | else | ||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}")); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -7,18 +7,17 @@ namespace Discord.Commands | |||||
{ | { | ||||
public override Task<TypeReaderResult> Read(IMessage context, string input) | public override Task<TypeReaderResult> Read(IMessage context, string input) | ||||
{ | { | ||||
//By Id | |||||
ulong id; | ulong id; | ||||
//By Id (1.0) | |||||
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) | if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) | ||||
{ | { | ||||
var msg = context.Channel.GetCachedMessage(id); | var msg = context.Channel.GetCachedMessage(id); | ||||
if (msg == null) | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found.")); | |||||
else | |||||
if (msg != null) | |||||
return Task.FromResult(TypeReaderResult.FromSuccess(msg)); | return Task.FromResult(TypeReaderResult.FromSuccess(msg)); | ||||
} | } | ||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Message Id.")); | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found.")); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,36 +1,46 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Globalization; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
namespace Discord.Commands | namespace Discord.Commands | ||||
{ | { | ||||
internal class RoleTypeReader : TypeReader | |||||
internal class RoleTypeReader<T> : TypeReader | |||||
where T : class, IRole | |||||
{ | { | ||||
public override Task<TypeReaderResult> Read(IMessage context, string input) | public override Task<TypeReaderResult> Read(IMessage context, string input) | ||||
{ | { | ||||
IGuildChannel guildChannel = context.Channel as IGuildChannel; | |||||
var guild = (context.Channel as IGuildChannel)?.Guild; | |||||
ulong id; | |||||
if (guildChannel != null) | |||||
if (guild != null) | |||||
{ | { | ||||
//By Id | |||||
ulong id; | |||||
if (MentionUtils.TryParseRole(input, out id) || ulong.TryParse(input, out id)) | |||||
{ | |||||
var channel = guildChannel.Guild.GetRole(id); | |||||
if (channel != null) | |||||
return Task.FromResult(TypeReaderResult.FromSuccess(channel)); | |||||
} | |||||
var results = new Dictionary<ulong, TypeReaderValue>(); | |||||
var roles = guild.Roles; | |||||
//By Name | |||||
var roles = guildChannel.Guild.Roles; | |||||
var filteredRoles = roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); | |||||
if (filteredRoles.Length > 1) | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple roles found.")); | |||||
else if (filteredRoles.Length == 1) | |||||
return Task.FromResult(TypeReaderResult.FromSuccess(filteredRoles[0])); | |||||
//By Mention (1.0) | |||||
if (MentionUtils.TryParseRole(input, out id)) | |||||
AddResult(results, guild.GetRole(id) as T, 1.00f); | |||||
//By Id (0.9) | |||||
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) | |||||
AddResult(results, guild.GetRole(id) as T, 0.90f); | |||||
//By Name (0.7-0.8) | |||||
foreach (var role in roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, role as T, role.Name == input ? 0.80f : 0.70f); | |||||
if (results.Count > 0) | |||||
return Task.FromResult(TypeReaderResult.FromSuccess(results)); | |||||
} | } | ||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found.")); | return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found.")); | ||||
} | } | ||||
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T role, float score) | |||||
{ | |||||
if (role != null && !results.ContainsKey(role.Id)) | |||||
results.Add(role.Id, new TypeReaderValue(role, score)); | |||||
} | |||||
} | } | ||||
} | } |
@@ -16,8 +16,7 @@ namespace Discord.Commands | |||||
T value; | T value; | ||||
if (_tryParse(input, out value)) | if (_tryParse(input, out value)) | ||||
return Task.FromResult(TypeReaderResult.FromSuccess(value)); | return Task.FromResult(TypeReaderResult.FromSuccess(value)); | ||||
else | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}")); | |||||
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}")); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,4 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Globalization; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
@@ -9,54 +11,78 @@ namespace Discord.Commands | |||||
{ | { | ||||
public override async Task<TypeReaderResult> Read(IMessage context, string input) | public override async Task<TypeReaderResult> Read(IMessage context, string input) | ||||
{ | { | ||||
IUser result = null; | |||||
//By Id | |||||
var results = new Dictionary<ulong, TypeReaderValue>(); | |||||
var guild = (context.Channel as IGuildChannel)?.Guild; | |||||
IReadOnlyCollection<IUser> channelUsers = await context.Channel.GetUsersAsync().ConfigureAwait(false); | |||||
IReadOnlyCollection<IGuildUser> guildUsers = null; | |||||
ulong id; | ulong id; | ||||
if (MentionUtils.TryParseUser(input, out id) || ulong.TryParse(input, out id)) | |||||
if (guild != null) | |||||
guildUsers = await guild.GetUsersAsync().ConfigureAwait(false); | |||||
//By Mention (1.0) | |||||
if (MentionUtils.TryParseUser(input, out id)) | |||||
{ | { | ||||
var user = await context.Channel.GetUserAsync(id).ConfigureAwait(false); | |||||
if (user != null) | |||||
result = user; | |||||
if (guild != null) | |||||
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); | |||||
else | |||||
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f); | |||||
} | } | ||||
//By Username + Discriminator | |||||
if (result == null) | |||||
//By Id (0.9) | |||||
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) | |||||
{ | { | ||||
int index = input.LastIndexOf('#'); | |||||
if (index >= 0) | |||||
if (guild != null) | |||||
AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f); | |||||
else | |||||
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f); | |||||
} | |||||
//By Username + Discriminator (0.7-0.85) | |||||
int index = input.LastIndexOf('#'); | |||||
if (index >= 0) | |||||
{ | |||||
string username = input.Substring(0, index); | |||||
ushort discriminator; | |||||
if (ushort.TryParse(input.Substring(index + 1), out discriminator)) | |||||
{ | { | ||||
string username = input.Substring(0, index); | |||||
ushort discriminator; | |||||
if (ushort.TryParse(input.Substring(index + 1), out discriminator)) | |||||
{ | |||||
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); | |||||
result = users.Where(x => | |||||
x.DiscriminatorValue == discriminator && | |||||
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); | |||||
} | |||||
var channelUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator && | |||||
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); | |||||
AddResult(results, channelUser as T, channelUser.Username == username ? 0.85f : 0.75f); | |||||
var guildUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator && | |||||
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault(); | |||||
AddResult(results, guildUser as T, guildUser.Username == username ? 0.80f : 0.70f); | |||||
} | } | ||||
} | } | ||||
//By Username | |||||
if (result == null) | |||||
//By Username (0.5-0.6) | |||||
{ | { | ||||
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); | |||||
var filteredUsers = users.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)).ToArray(); | |||||
if (filteredUsers.Length > 1) | |||||
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple users found."); | |||||
else if (filteredUsers.Length == 1) | |||||
result = filteredUsers[0]; | |||||
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f); | |||||
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); | |||||
} | } | ||||
if (result == null) | |||||
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); | |||||
//By Nickname (0.5-0.6) | |||||
{ | |||||
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f); | |||||
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); | |||||
} | |||||
T castResult = result as T; | |||||
if (castResult == null) | |||||
return TypeReaderResult.FromError(CommandError.CastFailed, $"User is not a {typeof(T).Name}."); | |||||
else | |||||
return TypeReaderResult.FromSuccess(castResult); | |||||
if (results.Count > 0) | |||||
return TypeReaderResult.FromSuccess(results.Values.ToArray()); | |||||
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); | |||||
} | |||||
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T user, float score) | |||||
{ | |||||
if (user != null && !results.ContainsKey(user.Id)) | |||||
results.Add(user.Id, new TypeReaderValue(user, score)); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -6,28 +6,53 @@ namespace Discord.Commands | |||||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | [DebuggerDisplay(@"{DebuggerDisplay,nq}")] | ||||
public struct ParseResult : IResult | public struct ParseResult : IResult | ||||
{ | { | ||||
public IReadOnlyList<object> Values { get; } | |||||
public IReadOnlyList<TypeReaderResult> ArgValues { get; } | |||||
public IReadOnlyList<TypeReaderResult> ParamValues { get; } | |||||
public CommandError? Error { get; } | public CommandError? Error { get; } | ||||
public string ErrorReason { get; } | public string ErrorReason { get; } | ||||
public bool IsSuccess => !Error.HasValue; | public bool IsSuccess => !Error.HasValue; | ||||
private ParseResult(IReadOnlyList<object> values, CommandError? error, string errorReason) | |||||
private ParseResult(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValue, CommandError? error, string errorReason) | |||||
{ | { | ||||
Values = values; | |||||
ArgValues = argValues; | |||||
ParamValues = paramValue; | |||||
Error = error; | Error = error; | ||||
ErrorReason = errorReason; | ErrorReason = errorReason; | ||||
} | } | ||||
public static ParseResult FromSuccess(IReadOnlyList<object> values) | |||||
=> new ParseResult(values, null, null); | |||||
public static ParseResult FromSuccess(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValues) | |||||
{ | |||||
for (int i = 0; i < argValues.Count; i++) | |||||
{ | |||||
if (argValues[i].Values.Count > 1) | |||||
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found."); | |||||
} | |||||
for (int i = 0; i < paramValues.Count; i++) | |||||
{ | |||||
if (paramValues[i].Values.Count > 1) | |||||
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found."); | |||||
} | |||||
return new ParseResult(argValues, paramValues, null, null); | |||||
} | |||||
public static ParseResult FromSuccess(IReadOnlyList<TypeReaderValue> argValues, IReadOnlyList<TypeReaderValue> paramValues) | |||||
{ | |||||
var argList = new TypeReaderResult[argValues.Count]; | |||||
for (int i = 0; i < argValues.Count; i++) | |||||
argList[i] = TypeReaderResult.FromSuccess(argValues[i]); | |||||
var paramList = new TypeReaderResult[paramValues.Count]; | |||||
for (int i = 0; i < paramValues.Count; i++) | |||||
paramList[i] = TypeReaderResult.FromSuccess(paramValues[i]); | |||||
return new ParseResult(argList, paramList, null, null); | |||||
} | |||||
public static ParseResult FromError(CommandError error, string reason) | public static ParseResult FromError(CommandError error, string reason) | ||||
=> new ParseResult(null, error, reason); | |||||
=> new ParseResult(null, null, error, reason); | |||||
public static ParseResult FromError(IResult result) | public static ParseResult FromError(IResult result) | ||||
=> new ParseResult(null, result.Error, result.ErrorReason); | |||||
=> new ParseResult(null, null, result.Error, result.ErrorReason); | |||||
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | ||||
private string DebuggerDisplay => IsSuccess ? $"Success ({Values.Count} Values)" : $"{Error}: {ErrorReason}"; | |||||
private string DebuggerDisplay => IsSuccess ? $"Success ({ArgValues.Count}{(ParamValues != null ? $" +{ParamValues.Count} Values" : "")})" : $"{Error}: {ErrorReason}"; | |||||
} | } | ||||
} | } |
@@ -1,32 +1,56 @@ | |||||
using System.Diagnostics; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Collections.Immutable; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
namespace Discord.Commands | namespace Discord.Commands | ||||
{ | { | ||||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | [DebuggerDisplay(@"{DebuggerDisplay,nq}")] | ||||
public struct TypeReaderResult : IResult | |||||
public struct TypeReaderValue | |||||
{ | { | ||||
public object Value { get; } | public object Value { get; } | ||||
public float Score { get; } | |||||
public TypeReaderValue(object value, float score) | |||||
{ | |||||
Value = value; | |||||
Score = score; | |||||
} | |||||
public override string ToString() => Value?.ToString(); | |||||
private string DebuggerDisplay => $"[{Value}, {Math.Round(Score, 2)}]"; | |||||
} | |||||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | |||||
public struct TypeReaderResult : IResult | |||||
{ | |||||
public IReadOnlyCollection<TypeReaderValue> Values { get; } | |||||
public CommandError? Error { get; } | public CommandError? Error { get; } | ||||
public string ErrorReason { get; } | public string ErrorReason { get; } | ||||
public bool IsSuccess => !Error.HasValue; | public bool IsSuccess => !Error.HasValue; | ||||
private TypeReaderResult(object value, CommandError? error, string errorReason) | |||||
private TypeReaderResult(IReadOnlyCollection<TypeReaderValue> values, CommandError? error, string errorReason) | |||||
{ | { | ||||
Value = value; | |||||
Values = values; | |||||
Error = error; | Error = error; | ||||
ErrorReason = errorReason; | ErrorReason = errorReason; | ||||
} | } | ||||
public static TypeReaderResult FromSuccess(object value) | public static TypeReaderResult FromSuccess(object value) | ||||
=> new TypeReaderResult(value, null, null); | |||||
=> new TypeReaderResult(ImmutableArray.Create(new TypeReaderValue(value, 1.0f)), null, null); | |||||
public static TypeReaderResult FromSuccess(TypeReaderValue value) | |||||
=> new TypeReaderResult(ImmutableArray.Create(value), null, null); | |||||
public static TypeReaderResult FromSuccess(IReadOnlyCollection<TypeReaderValue> values) | |||||
=> new TypeReaderResult(values, null, null); | |||||
public static TypeReaderResult FromError(CommandError error, string reason) | public static TypeReaderResult FromError(CommandError error, string reason) | ||||
=> new TypeReaderResult(null, error, reason); | => new TypeReaderResult(null, error, reason); | ||||
public static TypeReaderResult FromError(IResult result) | public static TypeReaderResult FromError(IResult result) | ||||
=> new TypeReaderResult(null, result.Error, result.ErrorReason); | => new TypeReaderResult(null, result.Error, result.ErrorReason); | ||||
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | ||||
private string DebuggerDisplay => IsSuccess ? $"Success ({Value})" : $"{Error}: {ErrorReason}"; | |||||
private string DebuggerDisplay => IsSuccess ? $"Success ({string.Join(", ", Values)})" : $"{Error}: {ErrorReason}"; | |||||
} | } | ||||
} | } |