You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.UnitTest;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// Excerpt of tensorflow/python/framework/ops_test.py (the <c>GraphTest</c> class).
      /// Each test keeps the original Python source in a block comment as the porting
      /// reference; tests that have not been ported yet are marked <c>[Ignore]</c> and
      /// have empty bodies.
      /// </summary>
      [TestClass]
      public class GraphTest : GraphModeTestBase
      {
          /// <summary>
          /// Resets the default graph before each test so graph state created by an
          /// earlier test is not visible to this one.
          /// </summary>
          [TestInitialize]
          public void SetUp()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Resets the default graph after each test so graph state created by this
          /// test is not visible to the next one.
          /// </summary>
          [TestCleanup]
          public void TearDown()
          {
              ops.reset_default_graph();
          }

          /// <summary>
          /// Asserts that <paramref name="expected"/> is the same instance (reference
          /// equality) as the current default graph. Mirrors the Python helper
          /// <c>self._AssertDefault</c> used by the reference tests below.
          /// </summary>
          /// <param name="expected">The graph that should currently be the default.</param>
          private void _AssertDefault(Graph expected)
          {
              // MSTest's signature is AreSame(expected, actual); keeping that order makes
              // the "expected"/"actual" wording of a failure message accurate.
              Assert.AreSame(expected, ops.get_default_graph());
          }

          /// <summary>
          /// Not yet ported. The Python reference expects
          /// <c>ops.reset_default_graph()</c> to raise <c>AssertionError</c> when it is
          /// called inside a <c>g0.as_default()</c> block.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testResetDefaultGraphNesting()
          {
              /*
              def testResetDefaultGraphNesting(self):
                g0 = ops.Graph()
                with self.assertRaises(AssertionError):
                  with g0.as_default():
                    ops.reset_default_graph()
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference expects that entering a graph's
          /// default context inside eager mode makes <c>context.executing_eagerly()</c>
          /// return false.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManagerCancelsEager()
          {
              /*
              def testGraphContextManagerCancelsEager(self):
                with context.eager_mode():
                  with ops.Graph().as_default():
                    self.assertFalse(context.executing_eagerly())
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference expects <c>g0.as_default()</c> to
          /// yield <c>g0</c> itself.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGraphContextManager()
          {
              /*
              def testGraphContextManager(self):
                g0 = ops.Graph()
                with g0.as_default() as g1:
                  self.assertIs(g0, g1)
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference checks how the default graph changes:
          /// creating a graph or a context manager does not change the default, while
          /// entering nested <c>as_default()</c> blocks does, and each exit restores the
          /// previous default.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testDefaultGraph()
          {
              /*
              def testDefaultGraph(self):
                orig = ops.get_default_graph()
                self._AssertDefault(orig)
                g0 = ops.Graph()
                self._AssertDefault(orig)
                context_manager_0 = g0.as_default()
                self._AssertDefault(orig)
                with context_manager_0 as g0:
                  self._AssertDefault(g0)
                  with ops.Graph().as_default() as g1:
                    self._AssertDefault(g1)
                  self._AssertDefault(g0)
                self._AssertDefault(orig)
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference expects a tensor to be feedable until
          /// <c>g.prevent_feeding(a)</c> is called on it.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testPreventFeeding()
          {
              /*
              def testPreventFeeding(self):
                g = ops.Graph()
                a = constant_op.constant(2.0)
                self.assertTrue(g.is_feedable(a))
                g.prevent_feeding(a)
                self.assertFalse(g.is_feedable(a))
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference expects <c>as_graph_element</c> to use
          /// an object's <c>_as_graph_element</c> hook when present, and to raise
          /// <c>TypeError</c> for objects without that hook.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testAsGraphElementConversions()
          {
              /*
              def testAsGraphElementConversions(self):

                class ConvertibleObj(object):

                  def _as_graph_element(self):
                    return "FloatOutput:0"

                class NonConvertibleObj(object):

                  pass

                g = ops.Graph()
                a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
                with self.assertRaises(TypeError):
                  g.as_graph_element(NonConvertibleObj())
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference is a regression test that a graph
          /// becomes garbage-collectable once every reference to it, its ops and its
          /// session is dropped.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testGarbageCollected()
          {
              /*
              # Regression test against creating custom __del__ functions in classes
              # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
              # cycles that require calling a __del__ method, because the __del__ method can
              # theoretically increase the object's refcount to "save" it from gc, and any
              # already-deleted objects in the cycle would have be to restored.)
              def testGarbageCollected(self):
                # Create a graph we can delete and a weak reference to monitor if it's gc'd
                g = ops.Graph()
                g_ref = weakref.ref(g)
                # Create some ops
                with g.as_default():
                  a = constant_op.constant(2.0)
                  b = constant_op.constant(3.0)
                  c = math_ops.add(a, b)
                # Create a session we can delete
                with session.Session(graph=g) as sess:
                  self.evaluate(c)
                # Delete all references and trigger gc
                del g
                del a
                del b
                del c
                del sess
                gc.collect()
                self.assertIsNone(g_ref())
              */
          }

          /// <summary>
          /// Not yet ported. The Python reference expects the graph to remain usable
          /// (a constant can still be built and evaluated) after an op construction
          /// fails with <c>ValueError</c> because of mismatched shapes.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShape()
          {
              /*
              def testRunnableAfterInvalidShape(self):
                with ops.Graph().as_default():
                  with self.assertRaises(ValueError):
                    math_ops.add([1, 2], [1, 2, 3])
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }

          /// <summary>
          /// Not yet ported. Same as <see cref="testRunnableAfterInvalidShape"/>, but the
          /// failing op is built inside a <c>_kernel_label_map</c> scope.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void testRunnableAfterInvalidShapeWithKernelLabelMap()
          {
              /*
              def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
                g = ops.Graph()
                with g.as_default():
                  with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
                    with self.assertRaises(ValueError):
                      test_ops.kernel_label_required(1)
                  a = constant_op.constant(1)
                  with session.Session() as sess:
                    self.evaluate(a)
              */
          }
      }
  }