You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

VariableTest.cs 3.1 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  6. using NumSharp.Core;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Unit tests for <c>tf.Variable</c>: creation, initialization, assignment,
      /// and arithmetic on variables, each evaluated through a TensorFlow session.
      /// </summary>
      [TestClass]
      public class VariableTest : Python
      {
          /// <summary>
          /// Runs a single variable's own <c>initializer</c> op (rather than the
          /// global initializer) and checks the value read back.
          /// </summary>
          [TestMethod]
          public void Initializer()
          {
              var x = tf.Variable(10, name: "x");
              using (var session = tf.Session())
              {
                  session.run(x.initializer);
                  var result = session.run(x);
                  Assert.AreEqual(10, (int)result);
              }
          }

          /// <summary>
          /// Creates string-valued variables, once with an explicit name and dtype
          /// (<c>tf.chars</c>) and once with defaults.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", "var1", tf.chars);
              var mammal2 = tf.Variable("Tiger");

              // Construction is what is under test; assert on the results so the
              // locals are not silently unused and the test's intent is explicit.
              Assert.IsNotNull(mammal1);
              Assert.IsNotNull(mammal2);
          }

          /// <summary>
          /// Initializes a variable from an expression over a constant (3 + 1)
          /// via the global initializer and checks the evaluated value.
          /// </summary>
          [TestMethod]
          public void ScalarVar()
          {
              var x = tf.constant(3, name: "x");
              var y = tf.Variable(x + 1, name: "y");
              var model = tf.global_variables_initializer();
              using (var session = tf.Session())
              {
                  session.run(model);
                  int result = session.run(y);
                  // MSTest's signature is AreEqual(expected, actual); keeping that order
                  // makes a failure message report the right value as "expected".
                  Assert.AreEqual(4, result);
              }
          }

          /// <summary>
          /// In a fresh graph, initializes a variable to 31, then runs an
          /// <c>assign</c> op to 12 and checks the value that op returns.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              with<Graph>(tf.Graph().as_default(), graph =>
              {
                  var variable = tf.Variable(31, name: "tree");
                  var init = tf.global_variables_initializer();

                  // Scope the session so it is disposed even when an assertion throws;
                  // previously it was created and never released.
                  using (var sess = tf.Session(graph))
                  {
                      sess.run(init);
                      var result = sess.run(variable);
                      Assert.AreEqual(31, (int)result);

                      var assign = variable.assign(12);
                      result = sess.run(assign);
                      Assert.AreEqual(12, (int)result);
                  }
              });
          }

          /// <summary>
          /// Builds an op that increments a float variable by 1.0 and runs it via
          /// <c>Operation.run()</c> inside a <c>with&lt;Session&gt;</c> scope.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              // Original alternative kept for reference:
              // tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
              var v1 = tf.Variable(10.0f, name: "v1");
              var inc_v1 = v1.assign(v1 + 1.0f);

              // Add an op to initialize the variables.
              var init_op = tf.global_variables_initializer();
              with<Session>(tf.Session(), sess =>
              {
                  sess.run(init_op);
                  // Do some work with the model.
                  // NOTE(review): this only checks that running the op does not throw;
                  // it does not assert that v1 became 11.0f.
                  inc_v1.op.run();
              });
          }

          /// <summary>
          /// Repeatedly adds 1 to a variable-backed tensor and evaluates the result.
          /// https://databricks.com/tensorflow/variables
          /// </summary>
          [TestMethod]
          public void Add()
          {
              int result = 0;
              Tensor x = tf.Variable(10, name: "x");
              var init_op = tf.global_variables_initializer();
              using (var session = tf.Session())
              {
                  session.run(init_op);
                  for (int i = 0; i < 5; i++)
                  {
                      // `x` is rebound to a new `x + 1` tensor each iteration; the variable
                      // itself is never assigned, so the final evaluation is 10 + 5.
                      x = x + 1;
                      result = session.run(x);
                      print(result);
                  }
              }
              Assert.AreEqual(15, result);
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。