diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index e27a5e3c..ff43c206 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -23,6 +23,15 @@ namespace Tensorflow { public Tensor log(Tensor x, string name = null) => gen_math_ops.log(x, name); + + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public Tensor erf(Tensor x, string name = null) + => math_ops.erf(x, name); } public Tensor abs(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 391ad9d5..eabd5cd1 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -265,6 +265,29 @@ namespace Tensorflow public static Tensor equal(Tx x, Ty y, string name = null) => gen_math_ops.equal(x, y, name: name); + /// + /// Computes the Gauss error function of `x` element-wise. + /// + /// + /// + /// + public static Tensor erf(Tensor x, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Erf", name, new { x }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Erf", name, + null, + x).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Erf", op.inputs, attrs, op.outputs); + }, + new Tensors(x)); + public static Tensor sqrt(Tensor x, string name = null) => gen_math_ops.sqrt(x, name: name); diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs index 26e89404..78f57b20 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs @@ -48,5 +48,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var x5 = tf.reduce_sum(b, (0, 1)); Assert.AreEqual(-4.7f, (float)x5); } + + [TestMethod] + public void Erf() + { + var erf = tf.math.erf(a, name: "erf"); + var expected = new float[] { 0.8427007f, -0.5204999f, 0.99999845f, -0.9970206f, 0f, -1f }; + var actual = erf.ToArray(); + Assert.IsTrue(Equal(expected, actual)); + } } }