using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
    /// <summary>
    /// Reproducibility tests for TensorFlow.NET random ops: after the seed is reset
    /// with <c>tf.set_random_seed</c> to the value used for a first draw, a second
    /// draw is expected to produce exactly the same values.
    /// </summary>
    [TestClass]
    public class RandomTest
    {
        /// <summary>
        /// Test the function of setting the random seed via <c>tf.set_random_seed</c>.
        /// Re-seeding with the same value should regenerate the same results from
        /// <c>tf.random_uniform</c> and <c>tf.random_shuffle</c>.
        /// </summary>
        [TestMethod]
        public void TFRandomSeedTest()
        {
            var initValue = np.arange(6).reshape(3, 2);

            tf.set_random_seed(1234);
            var a1 = tf.random_uniform(1);
            var b1 = tf.random_shuffle(tf.constant(initValue));

            // Refresh: re-seed with a different value and draw again so the random
            // state is disturbed before the original seed is restored below.
            tf.set_random_seed(10);
            tf.random_uniform(1);
            tf.random_shuffle(tf.constant(initValue));

            tf.set_random_seed(1234);
            var a2 = tf.random_uniform(1);
            var b2 = tf.random_shuffle(tf.constant(initValue));

            Assert.IsTrue(a1.numpy().array_equal(a2.numpy()), "random_uniform is not reproducible after re-seeding.");
            Assert.IsTrue(b1.numpy().array_equal(b2.numpy()), "random_shuffle is not reproducible after re-seeding.");
        }

        /// <summary>
        /// Same as <see cref="TFRandomSeedTest"/>, but the op-level <c>seed</c> parameter
        /// is also passed to every compared random op, in addition to the global seed.
        /// </summary>
        /// <remarks>
        /// NOTE(review): ignored in the original; presumably op-level seeding is not yet
        /// reproducible in this binding — confirm before re-enabling.
        /// </remarks>
        [TestMethod, Ignore]
        public void TFRandomSeedTest2()
        {
            var initValue = np.arange(6).reshape(3, 2);

            tf.set_random_seed(1234);
            var a1 = tf.random_uniform(1, seed: 1234);
            var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234);

            // Refresh: re-seed with a different value and draw again so the random
            // state is disturbed before the original seed is restored below.
            tf.set_random_seed(10);
            tf.random_uniform(1);
            tf.random_shuffle(tf.constant(initValue));

            // Both sides of the comparison must use the same global AND op-level seeds;
            // previously a2/b2 omitted the op-level seed, so the test compared two
            // differently-seeded draws.
            tf.set_random_seed(1234);
            var a2 = tf.random_uniform(1, seed: 1234);
            var b2 = tf.random_shuffle(tf.constant(initValue), seed: 1234);

            Assert.IsTrue(a1.numpy().array_equal(a2.numpy()), "random_uniform with op seed is not reproducible after re-seeding.");
            Assert.IsTrue(b1.numpy().array_equal(b2.numpy()), "random_shuffle with op seed is not reproducible after re-seeding.");
        }

        /// <summary>
        /// Same idea as <see cref="TFRandomSeedTest"/>, but exercising the functions under
        /// <c>tf.random</c> (<c>normal</c> and <c>truncated_normal</c>) rather than the
        /// top-level <c>tf</c> random functions.
        /// </summary>
        [TestMethod]
        public void TFRandomRaodomSeedTest()
        {
            tf.set_random_seed(1234);
            var a1 = tf.random.normal(1);
            var b1 = tf.random.truncated_normal(1);

            // Refresh: re-seed with a different value and draw again so the random
            // state is disturbed before the original seed is restored below.
            tf.set_random_seed(10);
            tf.random.normal(1);
            tf.random.truncated_normal(1);

            tf.set_random_seed(1234);
            var a2 = tf.random.normal(1);
            var b2 = tf.random.truncated_normal(1);

            Assert.IsTrue(a1.numpy().array_equal(a2.numpy()), "tf.random.normal is not reproducible after re-seeding.");
            Assert.IsTrue(b1.numpy().array_equal(b2.numpy()), "tf.random.truncated_normal is not reproducible after re-seeding.");
        }

        /// <summary>
        /// Same as <see cref="TFRandomRaodomSeedTest"/>, but the op-level <c>seed</c>
        /// parameter is also passed to every compared <c>tf.random</c> op.
        /// </summary>
        /// <remarks>
        /// NOTE(review): ignored in the original; presumably op-level seeding is not yet
        /// reproducible in this binding — confirm before re-enabling.
        /// </remarks>
        [TestMethod, Ignore]
        public void TFRandomRaodomSeedTest2()
        {
            tf.set_random_seed(1234);
            var a1 = tf.random.normal(1, seed: 1234);
            // b1 now uses the same op-level seed as b2; previously it had none, so the
            // truncated_normal comparison was between differently-seeded draws.
            var b1 = tf.random.truncated_normal(1, seed: 1234);

            // Refresh: re-seed with a different value and draw again so the random
            // state is disturbed before the original seed is restored below.
            tf.set_random_seed(10);
            tf.random.normal(1);
            tf.random.truncated_normal(1);

            tf.set_random_seed(1234);
            var a2 = tf.random.normal(1, seed: 1234);
            var b2 = tf.random.truncated_normal(1, seed: 1234);

            Assert.IsTrue(a1.numpy().array_equal(a2.numpy()), "tf.random.normal with op seed is not reproducible after re-seeding.");
            Assert.IsTrue(b1.numpy().array_equal(b2.numpy()), "tf.random.truncated_normal with op seed is not reproducible after re-seeding.");
        }
    }
}