From 78d17dea394ff8a8e66168a5416448faa0bd9464 Mon Sep 17 00:00:00 2001 From: lsylusiyao Date: Mon, 8 Feb 2021 18:13:50 +0800 Subject: [PATCH] Add random seed test to help reproduce training Currently the tests are set ignored because here's bug --- .../Basics/RandomTest.cs | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 test/TensorFlowNET.UnitTest/Basics/RandomTest.cs diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs new file mode 100644 index 00000000..ed70fa35 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -0,0 +1,106 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using System; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class RandomTest + { + /// + /// Test the function of setting random seed + /// This will help regenerate the same result + /// + [TestMethod, Ignore] + public void TFRandomSeedTest() + { + var initValue = np.arange(6).reshape(3, 2); + tf.set_random_seed(1234); + var a1 = tf.random_uniform(1); + var b1 = tf.random_shuffle(tf.constant(initValue)); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random_uniform(1); + tf.random_shuffle(tf.constant(initValue)); + + tf.set_random_seed(1234); + var a2 = tf.random_uniform(1); + var b2 = tf.random_shuffle(tf.constant(initValue)); + Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); + Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + } + + /// + /// compare to Test above, seed is also added in params + /// + [TestMethod, Ignore] + public void TFRandomSeedTest2() + { + var initValue = np.arange(6).reshape(3, 2); + tf.set_random_seed(1234); + var a1 = tf.random_uniform(1, seed:1234); + var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random_uniform(1); + tf.random_shuffle(tf.constant(initValue)); + + tf.set_random_seed(1234); + var a2 = tf.random_uniform(1); + var b2 = tf.random_shuffle(tf.constant(initValue)); + Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); + Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + } + + /// + /// This part we use funcs in tf.random rather than only tf + /// + [TestMethod, Ignore] + public void TFRandomRaodomSeedTest() + { + tf.set_random_seed(1234); + var a1 = tf.random.normal(1); + var b1 = tf.random.truncated_normal(1); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random.normal(1); + tf.random.truncated_normal(1); + + tf.set_random_seed(1234); + var a2 = tf.random.normal(1); + var b2 = tf.random.truncated_normal(1); + + Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); + Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + } + + /// + /// compare to Test above, seed is also added in params + /// + [TestMethod, Ignore] + public void TFRandomRaodomSeedTest2() + { + tf.set_random_seed(1234); + var a1 = tf.random.normal(1, seed:1234); + var b1 = tf.random.truncated_normal(1); + + // This part we consider to be a refresh + tf.set_random_seed(10); + tf.random.normal(1); + tf.random.truncated_normal(1); + + tf.set_random_seed(1234); + var a2 = tf.random.normal(1, seed:1234); + var b2 = tf.random.truncated_normal(1, seed:1234); + + Assert.IsTrue(a1.numpy().array_equal(a2.numpy())); + Assert.IsTrue(b1.numpy().array_equal(b2.numpy())); + } + } +} \ No newline at end of file