@@ -275,28 +275,23 @@ namespace Discord.Commands | |||||
if (builder.TypeReader == null) | if (builder.TypeReader == null) | ||||
{ | { | ||||
builder.TypeReader = service.GetTypeReaders(paramType)?.FirstOrDefault().Value | |||||
builder.TypeReader = service.GetTypeReaders(paramType, false)?.FirstOrDefault().Value | |||||
?? service.GetDefaultTypeReader(paramType); | ?? service.GetDefaultTypeReader(paramType); | ||||
} | } | ||||
} | } | ||||
internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) | internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) | ||||
{ | { | ||||
var readers = service.GetTypeReaders(paramType); | |||||
TypeReader reader = null; | |||||
var readers = service.GetTypeReaders(paramType, true); | |||||
if (readers != null) | if (readers != null) | ||||
{ | |||||
if (readers.TryGetValue(typeReaderType, out reader)) | |||||
return reader; | |||||
} | |||||
var overrideTypeReader = service.GetOverrideTypeReader(paramType); | |||||
if (overrideTypeReader != null) | |||||
return overrideTypeReader; | |||||
foreach (var kvp in readers) | |||||
if (kvp.Key == typeReaderType) | |||||
return kvp.Value; | |||||
//We dont have a cached type reader, create one | //We dont have a cached type reader, create one | ||||
reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services); | |||||
service.AddOverrideTypeReader(paramType, reader); | |||||
TypeReader reader = ReflectionUtils.CreateObject<TypeReader>(typeReaderType.GetTypeInfo(), service, services); | |||||
reader.IsOverride = true; | |||||
service.AddTypeReader(paramType, reader); | |||||
return reader; | return reader; | ||||
} | } | ||||
@@ -60,7 +60,7 @@ namespace Discord.Commands.Builders | |||||
if (type.GetTypeInfo().GetCustomAttribute<NamedArgumentTypeAttribute>() != null) | if (type.GetTypeInfo().GetCustomAttribute<NamedArgumentTypeAttribute>() != null) | ||||
{ | { | ||||
IsRemainder = true; | IsRemainder = true; | ||||
var reader = commands.GetTypeReaders(type)?.FirstOrDefault().Value; | |||||
var reader = commands.GetTypeReaders(type, false)?.FirstOrDefault().Value; | |||||
if (reader == null) | if (reader == null) | ||||
{ | { | ||||
Type readerType; | Type readerType; | ||||
@@ -80,8 +80,7 @@ namespace Discord.Commands.Builders | |||||
return reader; | return reader; | ||||
} | } | ||||
var readers = commands.GetTypeReaders(type); | |||||
var readers = commands.GetTypeReaders(type, false); | |||||
if (readers != null) | if (readers != null) | ||||
return readers.FirstOrDefault().Value; | return readers.FirstOrDefault().Value; | ||||
else | else | ||||
@@ -50,7 +50,6 @@ namespace Discord.Commands | |||||
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders; | private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders; | ||||
private readonly ConcurrentDictionary<Type, ConcurrentQueue<Type>> _userEntityTypeReaders; | private readonly ConcurrentDictionary<Type, ConcurrentQueue<Type>> _userEntityTypeReaders; | ||||
private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders; | private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders; | ||||
private readonly ConcurrentDictionary<Type, TypeReader> _overrideTypeReaders; | |||||
private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; | private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; | ||||
private readonly HashSet<ModuleInfo> _moduleDefs; | private readonly HashSet<ModuleInfo> _moduleDefs; | ||||
private readonly CommandMap _map; | private readonly CommandMap _map; | ||||
@@ -121,7 +120,6 @@ namespace Discord.Commands | |||||
_map = new CommandMap(this); | _map = new CommandMap(this); | ||||
_typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>(); | _typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>(); | ||||
_userEntityTypeReaders = new ConcurrentDictionary<Type, ConcurrentQueue<Type>>(); | _userEntityTypeReaders = new ConcurrentDictionary<Type, ConcurrentQueue<Type>>(); | ||||
_overrideTypeReaders = new ConcurrentDictionary<Type, TypeReader>(); | |||||
_defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>(); | _defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>(); | ||||
foreach (var type in PrimitiveParsers.SupportedTypes) | foreach (var type in PrimitiveParsers.SupportedTypes) | ||||
@@ -449,20 +447,10 @@ namespace Discord.Commands | |||||
var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); | var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); | ||||
readers[nullableReader.GetType()] = nullableReader; | readers[nullableReader.GetType()] = nullableReader; | ||||
} | } | ||||
internal void AddOverrideTypeReader(Type valueType, TypeReader valueTypeReader) | |||||
{ | |||||
_overrideTypeReaders[valueType] = valueTypeReader; | |||||
} | |||||
internal TypeReader GetOverrideTypeReader(Type type) | |||||
{ | |||||
if (_overrideTypeReaders.TryGetValue(type, out var definedTypeReader)) | |||||
return definedTypeReader; | |||||
return null; | |||||
} | |||||
internal IDictionary<Type, TypeReader> GetTypeReaders(Type type) | |||||
internal IEnumerable<KeyValuePair<Type, TypeReader>> GetTypeReaders(Type type, bool includeOverride) | |||||
{ | { | ||||
if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) | if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) | ||||
return definedTypeReaders; | |||||
return includeOverride ? definedTypeReaders : definedTypeReaders.Where(x => !x.Value.IsOverride); | |||||
var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type)); | var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type)); | ||||
@@ -490,7 +478,7 @@ namespace Discord.Commands | |||||
var entityTypeReaderType = entityReaders.Value.Value.First(); | var entityTypeReaderType = entityReaders.Value.Value.First(); | ||||
TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader; | TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader; | ||||
AddTypeReader(type, reader); | AddTypeReader(type, reader); | ||||
return GetTypeReaders(type); | |||||
return GetTypeReaders(type, false); | |||||
} | } | ||||
return null; | return null; | ||||
} | } | ||||
@@ -136,8 +136,8 @@ namespace Discord.Commands | |||||
var overridden = prop.GetCustomAttribute<OverrideTypeReaderAttribute>(); | var overridden = prop.GetCustomAttribute<OverrideTypeReaderAttribute>(); | ||||
var reader = (overridden != null) | var reader = (overridden != null) | ||||
? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services) | ? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services) | ||||
: (_commands.GetDefaultTypeReader(elemType) | |||||
?? _commands.GetTypeReaders(elemType).FirstOrDefault().Value); | |||||
: (_commands.GetTypeReaders(elemType, false)?.FirstOrDefault().Value | |||||
?? _commands.GetDefaultTypeReader(elemType)); | |||||
if (reader != null) | if (reader != null) | ||||
{ | { | ||||
@@ -8,6 +8,7 @@ namespace Discord.Commands | |||||
/// </summary> | /// </summary> | ||||
public abstract class TypeReader | public abstract class TypeReader | ||||
{ | { | ||||
internal bool IsOverride { get; set; } = false; | |||||
/// <summary> | /// <summary> | ||||
/// Attempts to parse the <paramref name="input"/> into the desired type. | /// Attempts to parse the <paramref name="input"/> into the desired type. | ||||
/// </summary> | /// </summary> | ||||