diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs new file mode 100644 index 000000000..61e6b2bc8 --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireNsfwAttribute.cs @@ -0,0 +1,20 @@ +using System; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + /// + /// Require that the command is invoked in a channel marked NSFW + /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + public class RequireNsfwAttribute : PreconditionAttribute + { + public override Task CheckPermissions(ICommandContext context, CommandInfo command, IDependencyMap map) + { + if (context.Channel.Nsfw) + return Task.FromResult(PreconditionResult.FromSuccess()); + else + return Task.FromResult(PreconditionResult.FromError("This command may only be invoked in an NSFW channel.")); + } + } +} diff --git a/src/Discord.Net.Core/Entities/Channels/IChannel.cs b/src/Discord.Net.Core/Entities/Channels/IChannel.cs index 72608ec6a..23d1893a9 100644 --- a/src/Discord.Net.Core/Entities/Channels/IChannel.cs +++ b/src/Discord.Net.Core/Entities/Channels/IChannel.cs @@ -8,6 +8,9 @@ namespace Discord /// Gets the name of this channel. string Name { get; } + /// Checks if the channel is NSFW. + bool Nsfw { get; } + /// Gets a collection of all users in this channel. IAsyncEnumerable> GetUsersAsync(CacheMode mode = CacheMode.AllowDownload, RequestOptions options = null); diff --git a/src/Discord.Net.Core/Utils/NsfwUtils.cs b/src/Discord.Net.Core/Utils/NsfwUtils.cs new file mode 100644 index 000000000..cda461ecc --- /dev/null +++ b/src/Discord.Net.Core/Utils/NsfwUtils.cs @@ -0,0 +1,10 @@ +namespace Discord +{ + public static class NsfwUtils + { + public static bool IsNsfw(IChannel channel) => + IsNsfw(channel.Name); + public static bool IsNsfw(string channelName) => + channelName.StartsWith("nsfw"); + } +} diff --git a/src/Discord.Net.Rest/Entities/Channels/RestChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RestChannel.cs index bc521784d..bc1fc5158 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RestChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RestChannel.cs @@ -46,6 +46,7 @@ namespace Discord.Rest //IChannel string IChannel.Name => null; + bool IChannel.Nsfw => NsfwUtils.IsNsfw(this); Task IChannel.GetUserAsync(ulong id, CacheMode mode, RequestOptions options) => Task.FromResult(null); //Overriden diff --git a/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs b/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs index 7e515978d..b12bb009c 100644 --- a/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs +++ b/src/Discord.Net.Rest/Entities/Channels/RpcVirtualMessageChannel.cs @@ -97,6 +97,7 @@ namespace Discord.Rest //IChannel string IChannel.Name { get { throw new NotSupportedException(); } } + bool IChannel.Nsfw { get { throw new NotSupportedException(); } } IAsyncEnumerable> IChannel.GetUsersAsync(CacheMode mode, RequestOptions options) { throw new NotSupportedException(); diff --git a/src/Discord.Net.Rpc/Entities/Channels/RpcChannel.cs b/src/Discord.Net.Rpc/Entities/Channels/RpcChannel.cs index cca559a31..7a22a4e6c 100644 --- a/src/Discord.Net.Rpc/Entities/Channels/RpcChannel.cs +++ b/src/Discord.Net.Rpc/Entities/Channels/RpcChannel.cs @@ -7,6 +7,7 @@ namespace Discord.Rpc public class RpcChannel : RpcEntity { public string Name { get; private set; } + public bool Nsfw => NsfwUtils.IsNsfw(Name); public DateTimeOffset CreatedAt => SnowflakeUtils.FromSnowflake(Id); diff --git a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs index 319e17c50..340ab91c1 100644 --- a/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs +++ b/src/Discord.Net.WebSocket/Entities/Channels/SocketChannel.cs @@ -40,6 +40,7 @@ namespace Discord.WebSocket //IChannel string IChannel.Name => null; + bool IChannel.Nsfw => NsfwUtils.IsNsfw(this); Task IChannel.GetUserAsync(ulong id, CacheMode mode, RequestOptions options) => Task.FromResult(null); //Overridden