diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs index 4a3a309e..c72bd2de 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs @@ -37,6 +37,6 @@ namespace Tensorflow.NumPy } public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null) - => throw new NotImplementedException(""); + => new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale)); } } diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 28471854..d186b400 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -47,6 +47,8 @@ namespace Tensorflow.NumPy return GetData(mask.ToArray()); else if (mask.dtype == TF_DataType.TF_INT64) return GetData(mask.ToArray().Select(x => Convert.ToInt32(x)).ToArray()); + else if (mask.dtype == TF_DataType.TF_FLOAT) + return GetData(mask.ToArray().Select(x => Convert.ToInt32(x)).ToArray()); throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs index b4f1e2f9..39d02dcd 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs @@ -18,6 +18,9 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray log(NDArray x) => new NDArray(tf.log(x)); + [AutoNumPy] + public static NDArray mean(NDArray x) => new NDArray(math_ops.reduce_mean(x)); + [AutoNumPy] public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2)); diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs index 13c5b141..d08f4e50 100644 --- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs +++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs @@ -14,6 +14,12 @@ namespace TensorFlowNET.UnitTest tf.Context.ensure_initialized(); } + public bool Equal(float f1, float f2) + { + var tolerance = .000001f; + return Math.Abs(f1 - f2) <= tolerance; + } + public bool Equal(float[] f1, float[] f2) { bool ret = false; diff --git a/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs index 38a4fbbe..5916324f 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs @@ -23,5 +23,15 @@ namespace TensorFlowNET.UnitTest.NumPy Assert.AreEqual(x.shape, 10); Assert.AreNotEqual(x.ToArray(), y.ToArray()); } + + /// + /// https://numpy.org/doc/stable/reference/random/generated/numpy.random.normal.html + /// + [TestMethod] + public void normal() + { + var x = np.random.normal(0, 0.1f, 1000); + Equal(np.mean(x), 0f); + } } }