You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Operations;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// Graph-related tests, excerpted from tensorflow/python/framework/ops_test.py.
      /// Most cases are still unported stubs: each ignored test keeps the Python
      /// original in a comment as the reference for a future port.
      /// </summary>
      [TestClass]
      public class GraphTest : PythonTest
      {
          /// <summary>
          /// Calls <c>ops.reset_default_graph()</c> before every test.
          /// </summary>
          [TestInitialize]
          public void SetUp()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Calls <c>ops.reset_default_graph()</c> after every test.
          /// </summary>
          [TestCleanup]
          public void TearDown()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Asserts that the graph returned by <c>ops.get_default_graph()</c> is the
          /// very same instance as <paramref name="expected"/>.
          /// </summary>
          /// <param name="expected">Graph that should currently be the default graph.</param>
          private void _AssertDefault(Graph expected)
          {
              // MSTest's signature is AreSame(expected, actual); keeping that order makes
              // a failure message label the two graphs correctly.
              Assert.AreSame(expected, ops.get_default_graph());
          }

          /// <summary>
          /// Not yet ported. Python reference: calling ops.reset_default_graph() while a
          /// graph is active through as_default() is expected to raise AssertionError.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testResetDefaultGraphNesting()
          {
              /*
              def testResetDefaultGraphNesting(self):
                g0 = ops.Graph()
                with self.assertRaises(AssertionError):
                  with g0.as_default():
                    ops.reset_default_graph()
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: inside context.eager_mode(), entering a
          /// graph's as_default() scope makes context.executing_eagerly() return false.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManagerCancelsEager()
          {
              /*
              def testGraphContextManagerCancelsEager(self):
                with context.eager_mode():
                  with ops.Graph().as_default():
                    self.assertFalse(context.executing_eagerly())
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: the object yielded by g0.as_default() is
          /// g0 itself.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManager()
          {
              /*
              def testGraphContextManager(self):
                g0 = ops.Graph()
                with g0.as_default() as g1:
                  self.assertIs(g0, g1)
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: the default graph only changes while an
          /// as_default() scope is entered and is restored as nested scopes exit.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testDefaultGraph()
          {
              /*
              def testDefaultGraph(self):
                orig = ops.get_default_graph()
                self._AssertDefault(orig)
                g0 = ops.Graph()
                self._AssertDefault(orig)
                context_manager_0 = g0.as_default()
                self._AssertDefault(orig)
                with context_manager_0 as g0:
                  self._AssertDefault(g0)
                  with ops.Graph().as_default() as g1:
                    self._AssertDefault(g1)
                  self._AssertDefault(g0)
                self._AssertDefault(orig)
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: a constant is feedable until
          /// g.prevent_feeding(a) has been called for it.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testPreventFeeding()
          {
              /*
              def testPreventFeeding(self):
                g = ops.Graph()
                a = constant_op.constant(2.0)
                self.assertTrue(g.is_feedable(a))
                g.prevent_feeding(a)
                self.assertFalse(g.is_feedable(a))
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: g.as_graph_element() converts an object
          /// that provides _as_graph_element() and raises TypeError for one that does not.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testAsGraphElementConversions()
          {
              /*
              def testAsGraphElementConversions(self):

                class ConvertibleObj(object):

                  def _as_graph_element(self):
                    return "FloatOutput:0"

                class NonConvertibleObj(object):

                  pass

                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
                with self.assertRaises(TypeError):
                  g.as_graph_element(NonConvertibleObj())
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: once every reference to a graph, its ops
          /// and its session has been deleted, a weak reference to the graph is cleared
          /// after gc.collect().
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGarbageCollected()
          {
              /*
              # Regression test against creating custom __del__ functions in classes
              # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
              # cycles that require calling a __del__ method, because the __del__ method can
              # theoretically increase the object's refcount to "save" it from gc, and any
              # already-deleted objects in the cycle would have be to restored.)
              def testGarbageCollected(self):
                # Create a graph we can delete and a weak reference to monitor if it's gc'd
                g = ops.Graph()
                g_ref = weakref.ref(g)
                # Create some ops
                with g.as_default():
                  a = constant_op.constant(2.0)
                  b = constant_op.constant(3.0)
                  c = math_ops.add(a, b)
                # Create a session we can delete
                with session.Session(graph=g) as sess:
                  self.evaluate(c)
                # Delete all references and trigger gc
                del g
                del a
                del b
                del c
                del sess
                gc.collect()
                self.assertIsNone(g_ref())
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: after math_ops.add() on mismatched shapes
          /// raises ValueError, a constant created in the same graph can still be
          /// evaluated in a session.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShape()
          {
              /*
              def testRunnableAfterInvalidShape(self):
                with ops.Graph().as_default():
                  with self.assertRaises(ValueError):
                    math_ops.add([1, 2], [1, 2, 3])
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }

          /// <summary>
          /// Not yet ported. Python reference: same scenario as
          /// testRunnableAfterInvalidShape, but the ValueError is raised by
          /// test_ops.kernel_label_required(1) inside a g._kernel_label_map(...) scope.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShapeWithKernelLabelMap()
          {
              /*
              def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
                g = ops.Graph()
                with g.as_default():
                  with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
                    with self.assertRaises(ValueError):
                      test_ops.kernel_label_required(1)
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。