From 13b1b63b4edc513bb16588dc9875196efb5642f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavel=20S=CC=8Cavara?= Date: Fri, 8 May 2020 01:23:14 +0200 Subject: [PATCH] fix NullReferenceException on session run with ITensorOrOperation signature --- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 3 ++- test/TensorFlowNET.UnitTest/SessionTest.cs | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index b25e77e5..5fcdc547 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -65,7 +65,8 @@ namespace Tensorflow public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) { - return _run(fetche, feed_dict)[0]; + var results = _run(fetche, feed_dict); + return fetche is Tensor ? results[0] : null; } public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index ddb6fe5f..95d2d447 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -133,6 +133,17 @@ namespace TensorFlowNET.UnitTest } } } + + [TestMethod] + public void Autocast_Case0() + { + var sess = tf.Session().as_default(); + ITensorOrOperation operation = tf.global_variables_initializer(); + // the cast to ITensorOrOperation is essential for the test of this method signature + var ret = sess.run(operation); + + ret.Should().BeNull(); + } [TestMethod] public void Autocast_Case1()