You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 4.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Python;
  4. namespace TensorFlowNET.UnitTest
  5. {
  /// <summary>
  /// Unit tests that exercise variable creation, initialization, variable
  /// scopes and assignment through the <c>tf</c> API.
  /// </summary>
  [TestClass]
  public class VariableTest
  {
      /// <summary>
      /// Creates an integer variable with the value 10, runs its own
      /// initializer op and checks that evaluating the variable yields 10.
      /// </summary>
      [TestMethod]
      public void Initializer()
      {
          var x = tf.Variable(10, name: "x");

          using (var session = tf.Session())
          {
              session.run(x.initializer);

              var result = session.run(x);
              Assert.AreEqual(10, (int)result[0]);
          }
      }

      /// <summary>
      /// Smoke test: constructs string variables, once with an explicit
      /// name and dtype and once with defaults. The test makes no
      /// assertions; it passes as long as construction does not throw.
      /// </summary>
      [TestMethod]
      public void StringVar()
      {
          var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars);
          var mammal2 = tf.Variable("Tiger");
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/variable_scope
      /// How to create a new variable: a variable requested inside nested
      /// variable scopes "foo" and "bar" is expected to be named
      /// "foo/bar/v:0".
      /// </summary>
      [TestMethod]
      public void VarCreation()
      {
          tf.Graph().as_default();

          tf_with(tf.variable_scope("foo"), delegate
          {
              tf_with(tf.variable_scope("bar"), delegate
              {
                  var v = tf.get_variable("v", new TensorShape(1));

                  // MSTest signature is AreEqual(expected, actual).
                  Assert.AreEqual("foo/bar/v:0", v.name);
              });
          });
      }

      /// <summary>
      /// How to re-enter a premade variable scope safely: the scope "foo"
      /// is captured, re-entered with <c>auxiliary_name_scope: false</c>,
      /// and its original name scope is restored so that both a variable
      /// and a constant created inside are named under "foo/".
      /// </summary>
      [TestMethod]
      public void ReenterVariableScope()
      {
          tf.Graph().as_default();

          // Capture the scope object handed to the callback for later reuse.
          variable_scope vs = null;
          tf_with(tf.variable_scope("foo"), v => vs = v);

          // Re-enter the variable scope.
          tf_with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
          {
              var vs1 = (VariableScope)v;

              // Restore the original name_scope.
              tf_with(tf.name_scope(vs1.original_name_scope), delegate
              {
                  var v1 = tf.get_variable("v", new TensorShape(1));
                  Assert.AreEqual("foo/v:0", v1.name);

                  var c1 = tf.constant(new int[] { 1 }, name: "c");
                  Assert.AreEqual("foo/c:0", c1.name);
              });
          });
      }

      /// <summary>
      /// Builds a variable from the expression <c>x + 1</c>, where x is the
      /// constant 3, runs the global initializer and checks that the
      /// variable evaluates to 4.
      /// </summary>
      [TestMethod]
      public void ScalarVar()
      {
          var x = tf.constant(3, name: "x");
          var y = tf.Variable(x + 1, name: "y");
          var model = tf.global_variables_initializer();

          using (var session = tf.Session())
          {
              session.run(model);

              int result = session.run(y)[0];
              Assert.AreEqual(4, result);
          }
      }

      /// <summary>
      /// Initializes a variable to 31, verifies the value, then runs an
      /// assign op with 12 and verifies that the op's result is 12.
      /// </summary>
      [TestMethod]
      public void Assign1()
      {
          var graph = tf.Graph().as_default();
          var variable = tf.Variable(31, name: "tree");
          var init = tf.global_variables_initializer();

          // Dispose the session when the test finishes, like the other tests do.
          using (var sess = tf.Session(graph))
          {
              sess.run(init);

              var result = sess.run(variable);
              Assert.AreEqual(31, (int)result[0]);

              var assign = variable.assign(12);
              result = sess.run(assign);
              Assert.AreEqual(12, (int)result[0]);
          }
      }

      /// <summary>
      /// Builds an op that assigns <c>v1 + 1.0f</c> back to v1 and runs it
      /// after the global initializer. The test makes no assertions; it
      /// passes as long as running the op does not throw.
      /// </summary>
      [TestMethod]
      public void Assign2()
      {
          var v1 = tf.Variable(10.0f, name: "v1");
          var inc_v1 = v1.assign(v1 + 1.0f);

          // Add an op to initialize the variables.
          var init_op = tf.global_variables_initializer();

          using (var sess = tf.Session())
          {
              sess.run(init_op);

              // Do some work with the model.
              inc_v1.op.run();
          }
      }

      /// <summary>
      /// https://databricks.com/tensorflow/variables
      /// Starting from a variable holding 10, rebinds x to <c>x + 1</c>
      /// five times, evaluating and printing it on each pass, and checks
      /// that the last evaluation yields 15.
      /// </summary>
      [TestMethod]
      public void Add()
      {
          tf.Graph().as_default();

          int result = 0;
          Tensor x = tf.Variable(10, name: "x");
          var init_op = tf.global_variables_initializer();

          using (var session = tf.Session())
          {
              session.run(init_op);

              for (int i = 0; i < 5; i++)
              {
                  // Each pass rebinds x to a new "x + 1" expression, so the
                  // evaluated value grows by one per iteration.
                  x = x + 1;
                  result = session.run(x)[0];
                  print(result);
              }
          }

          Assert.AreEqual(15, result);
      }
  }
  130. }