Browse Source

Miscellaneous improvements to commands

- OverloadInfo in Before/AfterExecute
  Now you know *exactly* what command is being executed.
- TypeReaders are given the IServiceProvider
  Now you can write TypeReaders for DB-backed data
pull/678/head
FiniteReality 8 years ago
parent
commit
f8e505e304
15 changed files with 39 additions and 36 deletions
  1. +4
    -3
      src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs
  2. +1
    -1
      src/Discord.Net.Commands/Builders/OverloadBuilder.cs
  3. +5
    -4
      src/Discord.Net.Commands/CommandParser.cs
  4. +4
    -7
      src/Discord.Net.Commands/CommandService.cs
  5. +2
    -2
      src/Discord.Net.Commands/IModuleBase.cs
  6. +4
    -4
      src/Discord.Net.Commands/Info/OverloadInfo.cs
  7. +4
    -2
      src/Discord.Net.Commands/Info/ParameterInfo.cs
  8. +4
    -4
      src/Discord.Net.Commands/ModuleBase.cs
  9. +1
    -1
      src/Discord.Net.Commands/Readers/ChannelTypeReader.cs
  10. +1
    -1
      src/Discord.Net.Commands/Readers/EnumTypeReader.cs
  11. +3
    -2
      src/Discord.Net.Commands/Readers/MessageTypeReader.cs
  12. +1
    -1
      src/Discord.Net.Commands/Readers/PrimitiveTypeReader.cs
  13. +1
    -1
      src/Discord.Net.Commands/Readers/RoleTypeReader.cs
  14. +3
    -2
      src/Discord.Net.Commands/Readers/TypeReader.cs
  15. +1
    -1
      src/Discord.Net.Commands/Readers/UserTypeReader.cs

+ 4
- 3
src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs View File

@@ -198,19 +198,20 @@ namespace Discord.Commands


var createInstance = ReflectionUtils.CreateBuilder<IModuleBase>(typeInfo, service); var createInstance = ReflectionUtils.CreateBuilder<IModuleBase>(typeInfo, service);


builder.Callback = async (ctx, args, map) =>
builder.Callback = async (ctx, args, map, overload) =>
{ {
var instance = createInstance(map); var instance = createInstance(map);
instance.SetContext(ctx); instance.SetContext(ctx);

try try
{ {
instance.BeforeExecute();
instance.BeforeExecute(overload);
var task = method.Invoke(instance, args) as Task ?? Task.Delay(0); var task = method.Invoke(instance, args) as Task ?? Task.Delay(0);
await task.ConfigureAwait(false); await task.ConfigureAwait(false);
} }
finally finally
{ {
instance.AfterExecute();
instance.AfterExecute(overload);
(instance as IDisposable)?.Dispose(); (instance as IDisposable)?.Dispose();
} }
}; };


+ 1
- 1
src/Discord.Net.Commands/Builders/OverloadBuilder.cs View File

@@ -4,7 +4,7 @@ using System.Threading.Tasks;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;


using CommandCallback = System.Func<Discord.Commands.ICommandContext, object[], System.IServiceProvider, System.Threading.Tasks.Task>;
using CommandCallback = System.Func<Discord.Commands.ICommandContext, object[], System.IServiceProvider, Discord.Commands.OverloadInfo, System.Threading.Tasks.Task>;


namespace Discord.Commands.Builders namespace Discord.Commands.Builders
{ {


+ 5
- 4
src/Discord.Net.Commands/CommandParser.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using System;
using System.Collections.Immutable;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;


@@ -13,7 +14,7 @@ namespace Discord.Commands
QuotedParameter QuotedParameter
} }
public static async Task<ParseResult> ParseArgs(OverloadInfo overload, ICommandContext context, string input, int startPos)
public static async Task<ParseResult> ParseArgs(OverloadInfo overload, ICommandContext context, IServiceProvider services, string input, int startPos)
{ {
ParameterInfo curParam = null; ParameterInfo curParam = null;
StringBuilder argBuilder = new StringBuilder(input.Length); StringBuilder argBuilder = new StringBuilder(input.Length);
@@ -110,7 +111,7 @@ namespace Discord.Commands
if (curParam == null) if (curParam == null)
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too many parameters."); return ParseResult.FromError(CommandError.BadArgCount, "The input text has too many parameters.");


var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false);
var typeReaderResult = await curParam.Parse(context, argString, services).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess && typeReaderResult.Error != CommandError.MultipleMatches) if (!typeReaderResult.IsSuccess && typeReaderResult.Error != CommandError.MultipleMatches)
return ParseResult.FromError(typeReaderResult); return ParseResult.FromError(typeReaderResult);


@@ -133,7 +134,7 @@ namespace Discord.Commands


if (curParam != null && curParam.IsRemainder) if (curParam != null && curParam.IsRemainder)
{ {
var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false);
var typeReaderResult = await curParam.Parse(context, argBuilder.ToString(), services).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess) if (!typeReaderResult.IsSuccess)
return ParseResult.FromError(typeReaderResult); return ParseResult.FromError(typeReaderResult);
argList.Add(typeReaderResult); argList.Add(typeReaderResult);


+ 4
- 7
src/Discord.Net.Commands/CommandService.cs View File

@@ -168,8 +168,7 @@ namespace Discord.Commands
await _moduleLock.WaitAsync().ConfigureAwait(false); await _moduleLock.WaitAsync().ConfigureAwait(false);
try try
{ {
ModuleInfo module;
if (!_typedModuleDefs.TryRemove(type, out module))
if (!_typedModuleDefs.TryRemove(type, out var module))
return false; return false;


return RemoveModuleInternal(module); return RemoveModuleInternal(module);
@@ -208,15 +207,13 @@ namespace Discord.Commands
} }
internal IDictionary<Type, TypeReader> GetTypeReaders(Type type) internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
{ {
ConcurrentDictionary<Type, TypeReader> definedTypeReaders;
if (_typeReaders.TryGetValue(type, out definedTypeReaders))
if (_typeReaders.TryGetValue(type, out var definedTypeReaders))
return definedTypeReaders; return definedTypeReaders;
return null; return null;
} }
internal TypeReader GetDefaultTypeReader(Type type) internal TypeReader GetDefaultTypeReader(Type type)
{ {
TypeReader reader;
if (_defaultTypeReaders.TryGetValue(type, out reader))
if (_defaultTypeReaders.TryGetValue(type, out var reader))
return reader; return reader;
var typeInfo = type.GetTypeInfo(); var typeInfo = type.GetTypeInfo();


@@ -287,7 +284,7 @@ namespace Discord.Commands
var rawParseResults = new List<ParseResult>(); var rawParseResults = new List<ParseResult>();
foreach (var overload in overloads) foreach (var overload in overloads)
{ {
rawParseResults.Add(await overload.ParseAsync(context, searchResult, preconditionResult).ConfigureAwait(false));
rawParseResults.Add(await overload.ParseAsync(context, services, searchResult, preconditionResult).ConfigureAwait(false));
} }


//order by average score //order by average score


+ 2
- 2
src/Discord.Net.Commands/IModuleBase.cs View File

@@ -4,8 +4,8 @@
{ {
void SetContext(ICommandContext context); void SetContext(ICommandContext context);


void BeforeExecute();
void BeforeExecute(OverloadInfo overload);
void AfterExecute();
void AfterExecute(OverloadInfo overload);
} }
} }

+ 4
- 4
src/Discord.Net.Commands/Info/OverloadInfo.cs View File

@@ -18,7 +18,7 @@ namespace Discord.Commands
private static readonly MethodInfo _convertParamsMethod = typeof(OverloadInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); private static readonly MethodInfo _convertParamsMethod = typeof(OverloadInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>(); private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>();


private readonly Func<ICommandContext, object[], IServiceProvider, Task> _action;
private readonly Func<ICommandContext, object[], IServiceProvider, OverloadInfo, Task> _action;


public CommandInfo Command { get; } public CommandInfo Command { get; }
public int Priority { get; } public int Priority { get; }
@@ -65,7 +65,7 @@ namespace Discord.Commands
return PreconditionResult.FromSuccess(); return PreconditionResult.FromSuccess();
} }


public async Task<ParseResult> ParseAsync(ICommandContext context, SearchResult searchResult, PreconditionResult? preconditionResult = null)
public async Task<ParseResult> ParseAsync(ICommandContext context, IServiceProvider services, SearchResult searchResult, PreconditionResult? preconditionResult = null)
{ {
if (!searchResult.IsSuccess) if (!searchResult.IsSuccess)
return ParseResult.FromError(searchResult); return ParseResult.FromError(searchResult);
@@ -84,7 +84,7 @@ namespace Discord.Commands


input = input.Substring(matchingAlias.Length); input = input.Substring(matchingAlias.Length);


return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false);
return await CommandParser.ParseArgs(this, context, services, input, 0).ConfigureAwait(false);
} }


public Task<ExecuteResult> ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services) public Task<ExecuteResult> ExecuteAsync(ICommandContext context, ParseResult parseResult, IServiceProvider services)
@@ -140,7 +140,7 @@ namespace Discord.Commands
await Command.Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false); await Command.Module.Service._cmdLogger.DebugAsync($"Executing {GetLogText(context)}").ConfigureAwait(false);
try try
{ {
await _action(context, args, map).ConfigureAwait(false);
await _action(context, args, map, this).ConfigureAwait(false);
} }
catch (Exception ex) catch (Exception ex)
{ {


+ 4
- 2
src/Discord.Net.Commands/Info/ParameterInfo.cs View File

@@ -56,9 +56,11 @@ namespace Discord.Commands
return PreconditionResult.FromSuccess(); return PreconditionResult.FromSuccess();
} }


public async Task<TypeReaderResult> Parse(ICommandContext context, string input)
public async Task<TypeReaderResult> Parse(ICommandContext context, string input, IServiceProvider services)
{ {
return await _reader.Read(context, input).ConfigureAwait(false);
services = services ?? EmptyServiceProvider.Instance;

return await _reader.Read(context, input, services).ConfigureAwait(false);
} }


public override string ToString() => Name; public override string ToString() => Name;


+ 4
- 4
src/Discord.Net.Commands/ModuleBase.cs View File

@@ -15,11 +15,11 @@ namespace Discord.Commands
return await Context.Channel.SendMessageAsync(message, isTTS, embed, options).ConfigureAwait(false); return await Context.Channel.SendMessageAsync(message, isTTS, embed, options).ConfigureAwait(false);
} }


protected virtual void BeforeExecute()
protected virtual void BeforeExecute(OverloadInfo overload)
{ {
} }


protected virtual void AfterExecute()
protected virtual void AfterExecute(OverloadInfo overload)
{ {
} }


@@ -32,8 +32,8 @@ namespace Discord.Commands
Context = newValue; Context = newValue;
} }


void IModuleBase.BeforeExecute() => BeforeExecute();
void IModuleBase.BeforeExecute(OverloadInfo overload) => BeforeExecute(overload);


void IModuleBase.AfterExecute() => AfterExecute();
void IModuleBase.AfterExecute(OverloadInfo overload) => AfterExecute(overload);
} }
} }

+ 1
- 1
src/Discord.Net.Commands/Readers/ChannelTypeReader.cs View File

@@ -9,7 +9,7 @@ namespace Discord.Commands
internal class ChannelTypeReader<T> : TypeReader internal class ChannelTypeReader<T> : TypeReader
where T : class, IChannel where T : class, IChannel
{ {
public override async Task<TypeReaderResult> Read(ICommandContext context, string input)
public override async Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
if (context.Guild != null) if (context.Guild != null)
{ {


+ 1
- 1
src/Discord.Net.Commands/Readers/EnumTypeReader.cs View File

@@ -44,7 +44,7 @@ namespace Discord.Commands
_enumsByValue = byValueBuilder.ToImmutable(); _enumsByValue = byValueBuilder.ToImmutable();
} }


public override Task<TypeReaderResult> Read(ICommandContext context, string input)
public override Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
T baseValue; T baseValue;
object enumValue; object enumValue;


+ 3
- 2
src/Discord.Net.Commands/Readers/MessageTypeReader.cs View File

@@ -1,4 +1,5 @@
using System.Globalization;
using System;
using System.Globalization;
using System.Threading.Tasks; using System.Threading.Tasks;


namespace Discord.Commands namespace Discord.Commands
@@ -6,7 +7,7 @@ namespace Discord.Commands
internal class MessageTypeReader<T> : TypeReader internal class MessageTypeReader<T> : TypeReader
where T : class, IMessage where T : class, IMessage
{ {
public override async Task<TypeReaderResult> Read(ICommandContext context, string input)
public override async Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
ulong id; ulong id;




+ 1
- 1
src/Discord.Net.Commands/Readers/PrimitiveTypeReader.cs View File

@@ -27,7 +27,7 @@ namespace Discord.Commands
_score = score; _score = score;
} }


public override Task<TypeReaderResult> Read(ICommandContext context, string input)
public override Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
T value; T value;
if (_tryParse(input, out value)) if (_tryParse(input, out value))


+ 1
- 1
src/Discord.Net.Commands/Readers/RoleTypeReader.cs View File

@@ -9,7 +9,7 @@ namespace Discord.Commands
internal class RoleTypeReader<T> : TypeReader internal class RoleTypeReader<T> : TypeReader
where T : class, IRole where T : class, IRole
{ {
public override Task<TypeReaderResult> Read(ICommandContext context, string input)
public override Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
ulong id; ulong id;




+ 3
- 2
src/Discord.Net.Commands/Readers/TypeReader.cs View File

@@ -1,9 +1,10 @@
using System.Threading.Tasks;
using System;
using System.Threading.Tasks;


namespace Discord.Commands namespace Discord.Commands
{ {
public abstract class TypeReader public abstract class TypeReader
{ {
public abstract Task<TypeReaderResult> Read(ICommandContext context, string input);
public abstract Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services);
} }
} }

+ 1
- 1
src/Discord.Net.Commands/Readers/UserTypeReader.cs View File

@@ -10,7 +10,7 @@ namespace Discord.Commands
internal class UserTypeReader<T> : TypeReader internal class UserTypeReader<T> : TypeReader
where T : class, IUser where T : class, IUser
{ {
public override async Task<TypeReaderResult> Read(ICommandContext context, string input)
public override async Task<TypeReaderResult> Read(ICommandContext context, string input, IServiceProvider services)
{ {
var results = new Dictionary<ulong, TypeReaderValue>(); var results = new Dictionary<ulong, TypeReaderValue>();
IReadOnlyCollection<IUser> channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way? IReadOnlyCollection<IUser> channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way?


Loading…
Cancel
Save