@@ -275,28 +275,23 @@ namespace Discord.Commands | |||
if (builder.TypeReader == null) | |||
{ | |||
builder.TypeReader = service.GetTypeReaders(paramType)?.FirstOrDefault().Value | |||
builder.TypeReader = service.GetTypeReaders(paramType, false)?.FirstOrDefault().Value | |||
?? service.GetDefaultTypeReader(paramType); | |||
} | |||
} | |||
internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) | |||
{ | |||
var readers = service.GetTypeReaders(paramType); | |||
TypeReader reader = null; | |||
var readers = service.GetTypeReaders(paramType, true); | |||
if (readers != null) | |||
{ | |||
if (readers.TryGetValue(typeReaderType, out reader)) | |||
return reader; | |||
} | |||
var overrideTypeReader = service.GetOverrideTypeReader(paramType); | |||
if (overrideTypeReader != null) | |||
return overrideTypeReader; | |||
foreach (var kvp in readers) | |||
if (kvp.Key == typeReaderType) | |||
return kvp.Value; | |||
//We dont have a cached type reader, create one | |||
reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services); | |||
service.AddOverrideTypeReader(paramType, reader); | |||
TypeReader reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services); | |||
reader.IsOverride = true; | |||
service.AddTypeReader(paramType, reader); | |||
return reader; | |||
} | |||
@@ -60,7 +60,7 @@ namespace Discord.Commands.Builders | |||
if (type.GetTypeInfo().GetCustomAttribute<NamedArgumentTypeAttribute>() != null) | |||
{ | |||
IsRemainder = true; | |||
var reader = commands.GetTypeReaders(type)?.FirstOrDefault().Value; | |||
var reader = commands.GetTypeReaders(type, false)?.FirstOrDefault().Value; | |||
if (reader == null) | |||
{ | |||
Type readerType; | |||
@@ -80,8 +80,7 @@ namespace Discord.Commands.Builders | |||
return reader; | |||
} | |||
var readers = commands.GetTypeReaders(type); | |||
var readers = commands.GetTypeReaders(type, false); | |||
if (readers != null) | |||
return readers.FirstOrDefault().Value; | |||
else | |||
@@ -50,7 +50,6 @@ namespace Discord.Commands | |||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders; | |||
private readonly ConcurrentDictionary<Type, ConcurrentQueue<Type>> _userEntityTypeReaders; | |||
private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders; | |||
private readonly ConcurrentDictionary<Type, TypeReader> _overrideTypeReaders; | |||
private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; | |||
private readonly HashSet<ModuleInfo> _moduleDefs; | |||
private readonly CommandMap _map; | |||
@@ -121,7 +120,6 @@ namespace Discord.Commands | |||
_map = new CommandMap(this); | |||
_typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>(); | |||
_userEntityTypeReaders = new ConcurrentDictionary<Type, ConcurrentQueue<Type>>(); | |||
_overrideTypeReaders = new ConcurrentDictionary<Type, TypeReader>(); | |||
_defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>(); | |||
foreach (var type in PrimitiveParsers.SupportedTypes) | |||
@@ -449,20 +447,10 @@ namespace Discord.Commands | |||
var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); | |||
readers[nullableReader.GetType()] = nullableReader; | |||
} | |||
internal void AddOverrideTypeReader(Type valueType, TypeReader valueTypeReader) | |||
{ | |||
_overrideTypeReaders[valueType] = valueTypeReader; | |||
} | |||
internal TypeReader GetOverrideTypeReader(Type type) | |||
{ | |||
if (_overrideTypeReaders.TryGetValue(type, out var definedTypeReader)) | |||
return definedTypeReader; | |||
return null; | |||
} | |||
internal IDictionary<Type, TypeReader> GetTypeReaders(Type type) | |||
internal IEnumerable<KeyValuePair<Type, TypeReader>> GetTypeReaders(Type type, bool includeOverride) | |||
{ | |||
if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) | |||
return definedTypeReaders; | |||
return includeOverride ? definedTypeReaders : definedTypeReaders.Where(x => !x.Value.IsOverride); | |||
var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type)); | |||
@@ -490,7 +478,7 @@ namespace Discord.Commands | |||
var entityTypeReaderType = entityReaders.Value.Value.First(); | |||
TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader; | |||
AddTypeReader(type, reader); | |||
return GetTypeReaders(type); | |||
return GetTypeReaders(type, false); | |||
} | |||
return null; | |||
} | |||
@@ -136,8 +136,8 @@ namespace Discord.Commands | |||
var overridden = prop.GetCustomAttribute<OverrideTypeReaderAttribute>(); | |||
var reader = (overridden != null) | |||
? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services) | |||
: (_commands.GetDefaultTypeReader(elemType) | |||
?? _commands.GetTypeReaders(elemType).FirstOrDefault().Value); | |||
: (_commands.GetTypeReaders(elemType, false)?.FirstOrDefault().Value | |||
?? _commands.GetDefaultTypeReader(elemType)); | |||
if (reader != null) | |||
{ | |||
@@ -8,6 +8,7 @@ namespace Discord.Commands | |||
/// </summary> | |||
public abstract class TypeReader | |||
{ | |||
internal bool IsOverride { get; set; } = false; | |||
/// <summary> | |||
/// Attempts to parse the <paramref name="input"/> into the desired type. | |||
/// </summary> | |||