Browse Source

np.random.normal

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
cb4b248200
5 changed files with 22 additions and 1 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
  2. +2
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  3. +3
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
  4. +6
    -0
      test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
  5. +10
    -0
      test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs

+ 1
- 1
src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs View File

@@ -37,6 +37,6 @@ namespace Tensorflow.NumPy
}

public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
=> throw new NotImplementedException("");
=> new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale));
}
}

+ 2
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -47,6 +47,8 @@ namespace Tensorflow.NumPy
return GetData(mask.ToArray<int>());
else if (mask.dtype == TF_DataType.TF_INT64)
return GetData(mask.ToArray<long>().Select(x => Convert.ToInt32(x)).ToArray());
else if (mask.dtype == TF_DataType.TF_FLOAT)
return GetData(mask.ToArray<float>().Select(x => Convert.ToInt32(x)).ToArray());

throw new NotImplementedException("");
}


+ 3
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs View File

@@ -18,6 +18,9 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray log(NDArray x) => new NDArray(tf.log(x));

[AutoNumPy]
public static NDArray mean(NDArray x) => new NDArray(math_ops.reduce_mean(x));

[AutoNumPy]
public static NDArray multiply(NDArray x1, NDArray x2) => new NDArray(tf.multiply(x1, x2));



+ 6
- 0
test/TensorFlowNET.UnitTest/EagerModeTestBase.cs View File

@@ -14,6 +14,12 @@ namespace TensorFlowNET.UnitTest
tf.Context.ensure_initialized();
}

public bool Equal(float f1, float f2)
{
var tolerance = .000001f;
return Math.Abs(f1 - f2) <= tolerance;
}

public bool Equal(float[] f1, float[] f2)
{
bool ret = false;


+ 10
- 0
test/TensorFlowNET.UnitTest/NumPy/Randomize.Test.cs View File

@@ -23,5 +23,15 @@ namespace TensorFlowNET.UnitTest.NumPy
Assert.AreEqual(x.shape, 10);
Assert.AreNotEqual(x.ToArray<int>(), y.ToArray<int>());
}

/// <summary>
/// https://numpy.org/doc/stable/reference/random/generated/numpy.random.normal.html
/// </summary>
[TestMethod]
public void normal()
{
var x = np.random.normal(0, 0.1f, 1000);
Equal(np.mean(x), 0f);
}
}
}

Loading…
Cancel
Save