Browse Source

Add random seed test to help reproduce training

Currently the tests are set ignored because here's bug
tags/v0.40-tf2.4-tstring
lsylusiyao Esther Hu 4 years ago
parent
commit
fb903d8703
1 changed files with 106 additions and 0 deletions
  1. +106
    -0
      test/TensorFlowNET.UnitTest/Basics/RandomTest.cs

+ 106
- 0
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -0,0 +1,106 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class RandomTest
{
/// <summary>
/// Test the function of setting random seed
/// This will help regenerate the same result
/// </summary>
[TestMethod, Ignore]
public void TFRandomSeedTest()
{
var initValue = np.arange(6).reshape(3, 2);
tf.set_random_seed(1234);
var a1 = tf.random_uniform(1);
var b1 = tf.random_shuffle(tf.constant(initValue));
// This part we consider to be a refresh
tf.set_random_seed(10);
tf.random_uniform(1);
tf.random_shuffle(tf.constant(initValue));
tf.set_random_seed(1234);
var a2 = tf.random_uniform(1);
var b2 = tf.random_shuffle(tf.constant(initValue));
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
}
/// <summary>
/// compare to Test above, seed is also added in params
/// </summary>
[TestMethod, Ignore]
public void TFRandomSeedTest2()
{
var initValue = np.arange(6).reshape(3, 2);
tf.set_random_seed(1234);
var a1 = tf.random_uniform(1, seed:1234);
var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234);
// This part we consider to be a refresh
tf.set_random_seed(10);
tf.random_uniform(1);
tf.random_shuffle(tf.constant(initValue));
tf.set_random_seed(1234);
var a2 = tf.random_uniform(1);
var b2 = tf.random_shuffle(tf.constant(initValue));
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
}
/// <summary>
/// This part we use funcs in tf.random rather than only tf
/// </summary>
[TestMethod, Ignore]
public void TFRandomRaodomSeedTest()
{
tf.set_random_seed(1234);
var a1 = tf.random.normal(1);
var b1 = tf.random.truncated_normal(1);
// This part we consider to be a refresh
tf.set_random_seed(10);
tf.random.normal(1);
tf.random.truncated_normal(1);
tf.set_random_seed(1234);
var a2 = tf.random.normal(1);
var b2 = tf.random.truncated_normal(1);
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
}
/// <summary>
/// compare to Test above, seed is also added in params
/// </summary>
[TestMethod, Ignore]
public void TFRandomRaodomSeedTest2()
{
tf.set_random_seed(1234);
var a1 = tf.random.normal(1, seed:1234);
var b1 = tf.random.truncated_normal(1);
// This part we consider to be a refresh
tf.set_random_seed(10);
tf.random.normal(1);
tf.random.truncated_normal(1);
tf.set_random_seed(1234);
var a2 = tf.random.normal(1, seed:1234);
var b2 = tf.random.truncated_normal(1, seed:1234);
Assert.IsTrue(a1.numpy().array_equal(a2.numpy()));
Assert.IsTrue(b1.numpy().array_equal(b2.numpy()));
}
}
}

Loading…
Cancel
Save