You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 4.6 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Unit tests covering variable creation, variable scopes, initialization
      /// and assignment through the <c>tf</c> API.
      /// </summary>
      [TestClass]
      public class VariableTest : Python
      {
          /// <summary>
          /// Runs a variable's own initializer op and checks that evaluating the
          /// variable yields its initial value.
          /// </summary>
          [TestMethod]
          public void Initializer()
          {
              var x = tf.Variable(10, name: "x");

              using (var session = tf.Session())
              {
                  session.run(x.initializer);
                  var result = session.run(x);

                  Assert.AreEqual(10, (int)result);
              }
          }

          /// <summary>
          /// Smoke test: constructs string-valued variables, with and without an
          /// explicit name/dtype. There are no assertions; the test passes as long
          /// as construction does not throw.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars);
              var mammal2 = tf.Variable("Tiger");
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/variable_scope
          /// how to create a new variable
          /// </summary>
          [TestMethod]
          public void VarCreation()
          {
              // NOTE(review): presumably switches to a fresh default graph so that
              // names created by other tests cannot collide - confirm.
              tf.Graph().as_default();

              with(tf.variable_scope("foo"), delegate
              {
                  with(tf.variable_scope("bar"), delegate
                  {
                      var v = tf.get_variable("v", new TensorShape(1));

                      // The nested scopes are expected to prefix the variable name.
                      Assert.AreEqual("foo/bar/v:0", v.name);
                  });
              });
          }

          /// <summary>
          /// how to reenter a premade variable scope safely
          /// </summary>
          [TestMethod]
          public void ReenterVariableScope()
          {
              tf.Graph().as_default();

              // Enter "foo" once only to capture the scope object handed to the callback.
              variable_scope vs = null;
              with(tf.variable_scope("foo"), v => vs = v);

              // Re-enter the variable scope.
              with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
              {
                  var vs1 = (VariableScope)v;

                  // Restore the original name_scope.
                  with(tf.name_scope(vs1.original_name_scope), delegate
                  {
                      var v1 = tf.get_variable("v", new TensorShape(1));
                      Assert.AreEqual("foo/v:0", v1.name);

                      var c1 = tf.constant(new int[] { 1 }, name: "c");
                      Assert.AreEqual("foo/c:0", c1.name);
                  });
              });
          }

          /// <summary>
          /// Initializes a variable from the expression <c>x + 1</c> over a constant
          /// and checks the evaluated value after running the global initializer.
          /// </summary>
          [TestMethod]
          public void ScalarVar()
          {
              var x = tf.constant(3, name: "x");
              var y = tf.Variable(x + 1, name: "y");
              var model = tf.global_variables_initializer();

              using (var session = tf.Session())
              {
                  session.run(model);
                  int result = session.run(y);

                  Assert.AreEqual(4, result);
              }
          }

          /// <summary>
          /// Checks the variable's initial value, then runs an assign op and checks
          /// that the value returned by that op is the newly assigned one.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              with(tf.Graph().as_default(), graph =>
              {
                  var variable = tf.Variable(31, name: "tree");
                  var init = tf.global_variables_initializer();

                  // Dispose the session when done, as the other tests in this class do.
                  using (var sess = tf.Session(graph))
                  {
                      sess.run(init);

                      var result = sess.run(variable);
                      Assert.AreEqual(31, (int)result);

                      var assign = variable.assign(12);
                      result = sess.run(assign);
                      Assert.AreEqual(12, (int)result);
                  }
              });
          }

          /// <summary>
          /// Smoke test: builds an op that assigns <c>v1 + 1</c> back to <c>v1</c>
          /// and runs it inside a session. There are no assertions on the result.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              var v1 = tf.Variable(10.0f, name: "v1"); //tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
              var inc_v1 = v1.assign(v1 + 1.0f);

              // Add an op to initialize the variables.
              var init_op = tf.global_variables_initializer();

              with(tf.Session(), sess =>
              {
                  sess.run(init_op);

                  // Do some work with the model.
                  inc_v1.op.run();
              });
          }

          /// <summary>
          /// https://databricks.com/tensorflow/variables
          /// </summary>
          [TestMethod]
          public void Add()
          {
              int result = 0;
              Tensor x = tf.Variable(10, name: "x");
              var init_op = tf.global_variables_initializer();

              using (var session = tf.Session())
              {
                  session.run(init_op);

                  // Each pass rebinds the local x to the tensor "x + 1" and evaluates
                  // it; no assign op is involved, the chain of additions simply grows.
                  for (int i = 0; i < 5; i++)
                  {
                      x = x + 1;
                      result = session.run(x);
                      print(result);
                  }
              }

              Assert.AreEqual(15, result);
          }
      }
  }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。