You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. namespace TensorFlowNET.UnitTest.ops_test
  4. {
  5. /// <summary>
  6. /// excerpt of tensorflow/python/framework/ops_test.py
  7. /// </summary>
  8. [TestClass]
  9. public class GraphTest : PythonTest
  10. {
  11. [TestInitialize]
  12. public void SetUp()
  13. {
  14. ops.reset_default_graph();
  15. }
  16. [TestCleanup]
  17. public void TearDown()
  18. {
  19. ops.reset_default_graph();
  20. }
  21. private void _AssertDefault(Graph expected) {
  22. Assert.AreSame(ops.get_default_graph(), expected);
  23. }
  24. [Ignore("Todo: Port")]
  25. [TestMethod]
  26. public void testResetDefaultGraphNesting()
  27. {
  28. /*
  29. def testResetDefaultGraphNesting(self):
  30. g0 = ops.Graph()
  31. with self.assertRaises(AssertionError):
  32. with g0.as_default():
  33. ops.reset_default_graph()
  34. */
  35. }
  36. [Ignore("Todo: Port")]
  37. [TestMethod]
  38. public void testGraphContextManagerCancelsEager()
  39. {
  40. /*
  41. def testGraphContextManagerCancelsEager(self):
  42. with context.eager_mode():
  43. with ops.Graph().as_default():
  44. self.assertFalse(context.executing_eagerly())
  45. */
  46. }
  47. [Ignore("Todo: Port")]
  48. [TestMethod]
  49. public void testGraphContextManager()
  50. {
  51. /*
  52. def testGraphContextManager(self):
  53. g0 = ops.Graph()
  54. with g0.as_default() as g1:
  55. self.assertIs(g0, g1)
  56. */
  57. }
  58. [Ignore("Todo: Port")]
  59. [TestMethod]
  60. public void testDefaultGraph()
  61. {
  62. /*
  63. def testDefaultGraph(self):
  64. orig = ops.get_default_graph()
  65. self._AssertDefault(orig)
  66. g0 = ops.Graph()
  67. self._AssertDefault(orig)
  68. context_manager_0 = g0.as_default()
  69. self._AssertDefault(orig)
  70. with context_manager_0 as g0:
  71. self._AssertDefault(g0)
  72. with ops.Graph().as_default() as g1:
  73. self._AssertDefault(g1)
  74. self._AssertDefault(g0)
  75. self._AssertDefault(orig)
  76. */
  77. }
  78. [Ignore("Todo: Port")]
  79. [TestMethod]
  80. public void testPreventFeeding()
  81. {
  82. /*
  83. def testPreventFeeding(self):
  84. g = ops.Graph()
  85. a = constant_op.constant(2.0)
  86. self.assertTrue(g.is_feedable(a))
  87. g.prevent_feeding(a)
  88. self.assertFalse(g.is_feedable(a))
  89. */
  90. }
  91. [Ignore("Todo: Port")]
  92. [TestMethod]
  93. public void testAsGraphElementConversions()
  94. {
  95. /*
  96. def testAsGraphElementConversions(self):
  97. class ConvertibleObj(object):
  98. def _as_graph_element(self):
  99. return "FloatOutput:0"
  100. class NonConvertibleObj(object):
  101. pass
  102. g = ops.Graph()
  103. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  104. self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
  105. with self.assertRaises(TypeError):
  106. g.as_graph_element(NonConvertibleObj())
  107. */
  108. }
  109. [Ignore("Todo: Port")]
  110. [TestMethod]
  111. public void testGarbageCollected()
  112. {
  113. /*
  114. # Regression test against creating custom __del__ functions in classes
  115. # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  116. # cycles that require calling a __del__ method, because the __del__ method can
  117. # theoretically increase the object's refcount to "save" it from gc, and any
  118. # already-deleted objects in the cycle would have be to restored.)
  119. def testGarbageCollected(self):
  120. # Create a graph we can delete and a weak reference to monitor if it's gc'd
  121. g = ops.Graph()
  122. g_ref = weakref.ref(g)
  123. # Create some ops
  124. with g.as_default():
  125. a = constant_op.constant(2.0)
  126. b = constant_op.constant(3.0)
  127. c = math_ops.add(a, b)
  128. # Create a session we can delete
  129. with session.Session(graph=g) as sess:
  130. self.evaluate(c)
  131. # Delete all references and trigger gc
  132. del g
  133. del a
  134. del b
  135. del c
  136. del sess
  137. gc.collect()
  138. self.assertIsNone(g_ref())
  139. */
  140. }
  141. [Ignore("Todo: Port")]
  142. [TestMethod]
  143. public void testRunnableAfterInvalidShape()
  144. {
  145. /*
  146. def testRunnableAfterInvalidShape(self):
  147. with ops.Graph().as_default():
  148. with self.assertRaises(ValueError):
  149. math_ops.add([1, 2], [1, 2, 3])
  150. a = constant_op.constant(1)
  151. with session.Session() as sess:
  152. self.evaluate(a)
  153. */
  154. }
  155. [Ignore("Todo: Port")]
  156. [TestMethod]
  157. public void testRunnableAfterInvalidShapeWithKernelLabelMap()
  158. {
  159. /*
  160. def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
  161. g = ops.Graph()
  162. with g.as_default():
  163. with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
  164. with self.assertRaises(ValueError):
  165. test_ops.kernel_label_required(1)
  166. a = constant_op.constant(1)
  167. with session.Session() as sess:
  168. self.evaluate(a)
  169. */
  170. }
  171. }
  172. }