diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 61b7caf4..c7aa4670 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.Operations;
+
namespace Tensorflow
{
public partial class tensorflow
@@ -50,6 +52,12 @@ namespace Tensorflow
public Tensor sum(Tensor x, Axis? axis = null, string name = null)
=> math_ops.reduce_sum(x, axis: axis, name: name);
+ public Tensor softplus(Tensor features, string name = null)
+ => nn_ops.softplus(features, name: name);
+
+ public Tensor tanh(Tensor x, string name = null)
+ => math_ops.tanh(x, name: name);
+
/// <summary>
/// Finds values and indices of the `k` largest entries for the last dimension.
/// </summary>
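The two wrappers above make softplus and tanh reachable from the public tf.math surface. A minimal usage sketch (not part of the patch), assuming the repo's usual `using static Tensorflow.Binding;` entry point, composing the two new wrappers into the mish formula by hand:

    using static Tensorflow.Binding;

    // mish(x) = x * tanh(softplus(x)), built from the two new wrappers.
    var x = tf.constant(new[] { -1.0f, 0.0f, 1.0f });
    var y = x * tf.math.tanh(tf.math.softplus(x));  // ≈ [-0.3034, 0, 0.8651]
    print(y);
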
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index d461595b..15b72f55 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -120,6 +120,16 @@ namespace Tensorflow.Gradients
};
}
+ [RegisterGradient("Softplus")]
+ public static Tensor[] _SoftplusGrad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var x = op.inputs[0];
+
+ var softplus = grad * math_ops.sigmoid(x);
+ return new Tensor[] { softplus };
+ }
+
[RegisterGradient("SquaredDifference")]
public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
{
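The new gradient follows from d/dx softplus(x) = d/dx ln(1 + e^x) = e^x / (1 + e^x) = sigmoid(x), so _SoftplusGrad simply scales the incoming gradient by sigmoid of the op's input. A sanity-check sketch (illustrative, not part of the patch), assuming TensorFlow.NET's eager tf.GradientTape works as in the project's README examples; at x = 0 the gradient should be sigmoid(0) = 0.5:

    using static Tensorflow.Binding;

    var x = tf.constant(0.0f);
    using var tape = tf.GradientTape();
    tape.watch(x);
    var y = tf.math.softplus(x);
    var dydx = tape.gradient(y, x);  // expect sigmoid(0) = 0.5
    print(dydx);
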
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index 5877d234..b8d5103c 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -132,6 +132,9 @@ namespace Tensorflow
return _softmax(logits, gen_nn_ops.softmax, axis, name);
}
+ public static Tensor softplus(Tensor features, string name = null)
+ => tf.Context.ExecuteOp("Softplus", name, new ExecuteOpArgs(features));
+
public static Tensor l2_loss(Tensor t, string name = null)
=> tf.Context.ExecuteOp("L2Loss", name, new ExecuteOpArgs(t));
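Like the L2Loss wrapper just below, the new softplus delegates to tf.Context.ExecuteOp, which runs the registered Softplus kernel eagerly or emits a graph node depending on the execution mode, so a single wrapper line suffices. A quick scalar cross-check against the closed form softplus(x) = ln(1 + e^x), going through the tf.math.softplus wrapper added above (illustrative snippet, not part of the patch):

    using System;
    using static Tensorflow.Binding;

    var v = 2.0f;
    var opResult = tf.math.softplus(tf.constant(v)).numpy();
    var closedForm = (float)Math.Log(1.0 + Math.Exp(v));  // ln(1 + e^2) ≈ 2.1269
    print($"{opResult} vs {closedForm}");
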
diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs
index 444c783e..37bddac7 100644
--- a/src/TensorFlowNET.Keras/Activations.cs
+++ b/src/TensorFlowNET.Keras/Activations.cs
@@ -20,12 +20,14 @@ namespace Tensorflow.Keras
=> tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
private static Activation _tanh = (features, name)
=> tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
+ private static Activation _mish = (features, name)
+ => features * tf.math.tanh(tf.math.softplus(features));
/// <summary>
/// Register the name-activation mapping in this static class.
/// </summary>
/// <param name="name"></param>
- /// <param name="method"></param>
+ /// <param name="activation"></param>
private static void RegisterActivation(string name, Activation activation)
{
_nameActivationMap[name] = activation;
@@ -42,6 +44,7 @@ namespace Tensorflow.Keras
RegisterActivation("sigmoid", _sigmoid);
RegisterActivation("softmax", _softmax);
RegisterActivation("tanh", _tanh);
+ RegisterActivation("mish", _mish);
}
public Activation Linear => _linear;
@@ -54,6 +57,7 @@ namespace Tensorflow.Keras
public Activation Tanh => _tanh;
+ public Activation Mish => _mish;
public static Activation GetActivationByName(string name)
{
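After registration, mish is reachable both through the keras.activations.Mish property and by string lookup through GetActivationByName, so layer constructors that take an activation name can resolve "mish" like any built-in. A short sketch (not part of the patch), assuming the usual Tensorflow.KerasApi entry point and that the containing class is named Activations, as the filename suggests:

    using static Tensorflow.Binding;
    using static Tensorflow.KerasApi;

    var x = tf.constant(new[] { -1.0f, 0.0f, 1.0f });
    var y = keras.activations.Mish(x);  // ≈ [-0.3034, 0, 0.8651]

    // Equivalent lookup by the newly registered name:
    var byName = Tensorflow.Keras.Activations.GetActivationByName("mish");
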
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
index 904601b3..1f45c518 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
@@ -94,5 +94,16 @@ namespace TensorFlowNET.Keras.UnitTest {
NDArray expected = new NDArray(new float[] { -0.14227762f, -0.23840584f, -0.26894143f, 0f, 0.7310586f, 1.761594f });
Assert.AreEqual(expected, output.numpy());
}
+
+ /// <summary>
+ /// https://www.tensorflow.org/addons/api_docs/python/tfa/activations/mish
+ /// </summary>
+ [TestMethod]
+ public void Mish()
+ {
+ var x = tf.constant(new[] { 1.0, 0.0, 1.0 }, dtype: tf.float32);
+ var output = keras.activations.Mish(x);
+ Assert.AreEqual(new[] { 0.86509836f, 0f, 0.86509836f }, output.numpy());
+ }
}
}
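The test's expected constant follows directly from the definition: mish(1) = 1 · tanh(ln(1 + e)) ≈ 0.8650984, and mish(0) = 0 since the product vanishes, matching the tfa.activations.mish reference linked above. A standalone scalar check with System.Math (illustrative only):

    using System;

    static double Mish(double v) => v * Math.Tanh(Math.Log(1.0 + Math.Exp(v)));

    Console.WriteLine(Mish(1.0));  // 0.865098...
    Console.WriteLine(Mish(0.0));  // 0
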