You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

GraphTest.cs 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.UnitTest;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// Graph tests; an excerpt of tensorflow/python/framework/ops_test.py.
      /// </summary>
      /// <remarks>
      /// Every test below is still marked <c>[Ignore("Todo: Port")]</c>. Each empty
      /// body keeps the original Python test as a comment so it can be used as the
      /// reference when the port is written.
      /// </remarks>
      [TestClass]
      public class GraphTest : GraphModeTestBase
      {
          /// <summary>
          /// Resets the default graph before each test.
          /// </summary>
          [TestInitialize]
          public void SetUp()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Resets the default graph after each test.
          /// </summary>
          [TestCleanup]
          public void TearDown()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Asserts that the current default graph is the very same instance as
          /// <paramref name="expected"/>.
          /// </summary>
          /// <param name="expected">The graph that should currently be the default graph.</param>
          private void _AssertDefault(Graph expected)
          {
              // MSTest's Assert.AreSame takes (expected, actual); passing them in this
              // order keeps the failure message labels correct.
              Assert.AreSame(expected, ops.get_default_graph());
          }

          /// <summary>
          /// Not yet ported. The Python original expects <c>ops.reset_default_graph()</c>
          /// to raise an <c>AssertionError</c> when called inside <c>g0.as_default()</c>.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testResetDefaultGraphNesting()
          {
              /*
              def testResetDefaultGraphNesting(self):
                g0 = ops.Graph()
                with self.assertRaises(AssertionError):
                  with g0.as_default():
                    ops.reset_default_graph()
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that entering a graph's
          /// <c>as_default()</c> context inside eager mode turns eager execution off.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManagerCancelsEager()
          {
              /*
              def testGraphContextManagerCancelsEager(self):
                with context.eager_mode():
                  with ops.Graph().as_default():
                    self.assertFalse(context.executing_eagerly())
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that <c>g0.as_default()</c>
          /// yields the same graph object it was called on.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManager()
          {
              /*
              def testGraphContextManager(self):
                g0 = ops.Graph()
                with g0.as_default() as g1:
                  self.assertIs(g0, g1)
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that the default graph changes
          /// only while an <c>as_default()</c> context is entered, and is restored
          /// correctly when nested contexts exit.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testDefaultGraph()
          {
              /*
              def testDefaultGraph(self):
                orig = ops.get_default_graph()
                self._AssertDefault(orig)
                g0 = ops.Graph()
                self._AssertDefault(orig)
                context_manager_0 = g0.as_default()
                self._AssertDefault(orig)
                with context_manager_0 as g0:
                  self._AssertDefault(g0)
                  with ops.Graph().as_default() as g1:
                    self._AssertDefault(g1)
                  self._AssertDefault(g0)
                self._AssertDefault(orig)
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that a tensor is feedable until
          /// <c>g.prevent_feeding(a)</c> is called, and not feedable afterwards.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testPreventFeeding()
          {
              /*
              def testPreventFeeding(self):
                g = ops.Graph()
                a = constant_op.constant(2.0)
                self.assertTrue(g.is_feedable(a))
                g.prevent_feeding(a)
                self.assertFalse(g.is_feedable(a))
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that <c>g.as_graph_element</c>
          /// accepts an object exposing <c>_as_graph_element</c> and raises
          /// <c>TypeError</c> for an object that does not.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testAsGraphElementConversions()
          {
              /*
              def testAsGraphElementConversions(self):

                class ConvertibleObj(object):

                  def _as_graph_element(self):
                    return "FloatOutput:0"

                class NonConvertibleObj(object):
                  pass

                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
                with self.assertRaises(TypeError):
                  g.as_graph_element(NonConvertibleObj())
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that a graph is garbage
          /// collected once every reference to it, its ops and its session is dropped.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGarbageCollected()
          {
              /*
              # Regression test against creating custom __del__ functions in classes
              # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
              # cycles that require calling a __del__ method, because the __del__ method can
              # theoretically increase the object's refcount to "save" it from gc, and any
              # already-deleted objects in the cycle would have to be restored.)
              def testGarbageCollected(self):
                # Create a graph we can delete and a weak reference to monitor if it's gc'd
                g = ops.Graph()
                g_ref = weakref.ref(g)
                # Create some ops
                with g.as_default():
                  a = constant_op.constant(2.0)
                  b = constant_op.constant(3.0)
                  c = math_ops.add(a, b)
                # Create a session we can delete
                with session.Session(graph=g) as sess:
                  self.evaluate(c)
                # Delete all references and trigger gc
                del g
                del a
                del b
                del c
                del sess
                gc.collect()
                self.assertIsNone(g_ref())
              */
          }

          /// <summary>
          /// Not yet ported. The Python original checks that a graph can still be run
          /// after an op with mismatched shapes raised <c>ValueError</c> during
          /// construction.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShape()
          {
              /*
              def testRunnableAfterInvalidShape(self):
                with ops.Graph().as_default():
                  with self.assertRaises(ValueError):
                    math_ops.add([1, 2], [1, 2, 3])
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }

          /// <summary>
          /// Not yet ported. Same scenario as <see cref="testRunnableAfterInvalidShape"/>,
          /// but the failing op is created inside a <c>_kernel_label_map</c> context.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShapeWithKernelLabelMap()
          {
              /*
              def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
                g = ops.Graph()
                with g.as_default():
                  with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
                    with self.assertRaises(ValueError):
                      test_ops.kernel_label_required(1)
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }
      }
  }