You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RandomTest.cs 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Linq;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Reproducibility tests for TensorFlow.NET random ops. Each test draws
      /// values after <c>tf.set_random_seed(1234)</c>, draws again under a
      /// different seed, then re-seeds with 1234 and expects the first draws to
      /// be regenerated exactly.
      /// </summary>
      [TestClass]
      public class RandomTest
      {
          /// <summary>
          /// Setting the global random seed must make <c>tf.random_uniform</c> and
          /// <c>tf.random_shuffle</c> regenerate the same results.
          /// </summary>
          [TestMethod]
          public void TFRandomSeedTest()
          {
              var initValue = np.arange(6).reshape((3, 2));

              tf.set_random_seed(1234);
              var a1 = tf.random_uniform(1);
              var b1 = tf.random_shuffle(tf.constant(initValue));

              // "Refresh": draw under a different global seed in between.
              tf.set_random_seed(10);
              tf.random_uniform(1);
              tf.random_shuffle(tf.constant(initValue));

              tf.set_random_seed(1234);
              var a2 = tf.random_uniform(1);
              var b2 = tf.random_shuffle(tf.constant(initValue));

              // Compare the values converted via numpy(), not the tensor handles.
              Assert.AreEqual(a1.numpy(), a2.numpy());
              Assert.AreEqual(b1.numpy(), b2.numpy());
          }

          /// <summary>
          /// Same as <see cref="TFRandomSeedTest"/>, but every seeded draw also
          /// passes an op-level <c>seed: 1234</c>.
          /// </summary>
          /// <remarks>
          /// Both the first and the last draws now pass the op-level seed, so the
          /// two halves being compared use identical calls. Assertions use
          /// <c>numpy()</c> values, matching the active tests.
          /// NOTE(review): still ignored; confirm op-level seeding is
          /// deterministic in this build before removing <c>Ignore</c>.
          /// </remarks>
          [TestMethod, Ignore]
          public void TFRandomSeedTest2()
          {
              var initValue = np.arange(6).reshape((3, 2));

              tf.set_random_seed(1234);
              var a1 = tf.random_uniform(1, seed: 1234);
              var b1 = tf.random_shuffle(tf.constant(initValue), seed: 1234);

              // "Refresh": draw under a different global seed in between.
              tf.set_random_seed(10);
              tf.random_uniform(1);
              tf.random_shuffle(tf.constant(initValue));

              tf.set_random_seed(1234);
              var a2 = tf.random_uniform(1, seed: 1234);
              var b2 = tf.random_shuffle(tf.constant(initValue), seed: 1234);

              // Compare values; comparing Tensor objects directly would
              // presumably check identity rather than contents.
              Assert.AreEqual(a1.numpy(), a2.numpy());
              Assert.AreEqual(b1.numpy(), b2.numpy());
          }

          /// <summary>
          /// Same reproducibility check as <see cref="TFRandomSeedTest"/>, but
          /// through the <c>tf.random</c> API (<c>normal</c> and
          /// <c>truncated_normal</c>).
          /// </summary>
          /// <remarks>
          /// The method name ("Raodom") is kept unchanged so existing test
          /// filters and reports keep matching.
          /// </remarks>
          [TestMethod]
          public void TFRandomRaodomSeedTest()
          {
              tf.set_random_seed(1234);
              var a1 = tf.random.normal(1);
              var b1 = tf.random.truncated_normal(1);

              // "Refresh": draw under a different global seed in between.
              tf.set_random_seed(10);
              tf.random.normal(1);
              tf.random.truncated_normal(1);

              tf.set_random_seed(1234);
              var a2 = tf.random.normal(1);
              var b2 = tf.random.truncated_normal(1);

              Assert.AreEqual(a1.numpy(), a2.numpy());
              Assert.AreEqual(b1.numpy(), b2.numpy());
          }

          /// <summary>
          /// Same as <see cref="TFRandomRaodomSeedTest"/>, but every seeded draw
          /// also passes an op-level <c>seed: 1234</c>.
          /// </summary>
          /// <remarks>
          /// Previously <c>b1</c> omitted the op-level seed while <c>b2</c> passed
          /// it. Both halves now use identical calls, and assertions compare
          /// <c>numpy()</c> values like the active tests.
          /// NOTE(review): still ignored; confirm op-level seeding is
          /// deterministic in this build before removing <c>Ignore</c>.
          /// </remarks>
          [TestMethod, Ignore]
          public void TFRandomRaodomSeedTest2()
          {
              tf.set_random_seed(1234);
              var a1 = tf.random.normal(1, seed: 1234);
              var b1 = tf.random.truncated_normal(1, seed: 1234);

              // "Refresh": draw under a different global seed in between.
              tf.set_random_seed(10);
              tf.random.normal(1);
              tf.random.truncated_normal(1);

              tf.set_random_seed(1234);
              var a2 = tf.random.normal(1, seed: 1234);
              var b2 = tf.random.truncated_normal(1, seed: 1234);

              Assert.AreEqual(a1.numpy(), a2.numpy());
              Assert.AreEqual(b1.numpy(), b2.numpy());
          }
      }
  }