From 9d71cad96ecb69cd83c2b113fc808b608fbd7875 Mon Sep 17 00:00:00 2001 From: Alexander Novikov Date: Thu, 14 Sep 2023 11:21:18 +0000 Subject: [PATCH] using and no IEnumerable --- .../ControlFlowTest/WhileContextTestCase.cs | 4 ++-- .../GradientTest/GradientTest.cs | 16 ++++++++------ .../PythonTest.cs | 22 +++++++++---------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs index 4dee6133..e93324f3 100644 --- a/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs @@ -24,13 +24,13 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest private void _testWhileContextHelper(int maximum_iterations) { // TODO: implement missing code dependencies - var sess = this.cached_session(); + using var sess = this.cached_session(); var i = constant_op.constant(0, name: "i"); var c = new Func(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c")); var b = new Func(x => math_ops.add(x, 1, name: "c")); //control_flow_ops.while_loop( // c, b, i , maximum_iterations: tf.constant(maximum_iterations)); - foreach (Operation op in sess.Single().graph.get_operations()) + foreach (Operation op in sess.graph.get_operations()) { var control_flow_context = op._get_control_flow_context(); /*if (control_flow_context != null) diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 0b4d79bb..099c1162 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -394,13 +394,15 @@ namespace TensorFlowNET.UnitTest.Gradient // Test that we differentiate both 'x' and 'y' correctly when x is a // predecessor of y. - var sess = self.cached_session().Single(); - var x = tf.constant(1.0); - var y = x * 2.0; - var z = y * 3.0; - var grads = tf.gradients(z, new[] { x, y }); - self.assertTrue(all(grads.Select(x => x != null))); - self.assertEqual(6.0, grads[0].eval()); + using (self.cached_session()) + { + var x = tf.constant(1.0); + var y = x * 2.0; + var z = y * 3.0; + var grads = tf.gradients(z, new[] { x, y }); + self.assertTrue(all(grads.Select(x => x != null))); + self.assertEqual(6.0, grads[0].eval()); + } } [Ignore("TODO")] diff --git a/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs b/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs index 90abc0cc..ccf59f5a 100644 --- a/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs @@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest } ///Returns a TensorFlow Session for use in executing tests. - public IEnumerable cached_session( + public Session cached_session( Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) { // This method behaves differently than self.session(): for performance reasons @@ -267,9 +267,8 @@ namespace TensorFlowNET.UnitTest { var sess = self._get_cached_session( graph, config, force_gpu, crash_if_inconsistent_args: true); - var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); - return cached; - + using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); + return cached; } } @@ -318,13 +317,12 @@ namespace TensorFlowNET.UnitTest return s.as_default(); } - private IEnumerable _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) + private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) { // Set the session and its graph to global default and constrain devices.""" - // if context.executing_eagerly(): - // yield None - // else: - { + if (tf.executing_eagerly()) + return null; + else { sess.graph.as_default(); sess.as_default(); { @@ -340,13 +338,13 @@ namespace TensorFlowNET.UnitTest using (sess.graph.device(gpu_name)) { yield return sess; }*/ - yield return sess; + return sess; } else if (use_gpu) - yield return sess; + return sess; else using (sess.graph.device("/device:CPU:0")) - yield return sess; + return sess; } }