diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index b25e77e5..5fcdc547 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -65,7 +65,8 @@ namespace Tensorflow public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) { - return _run(fetche, feed_dict)[0]; + var results = _run(fetche, feed_dict); + return fetche is Tensor ? results[0] : null; } public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index ddb6fe5f..95d2d447 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -133,6 +133,17 @@ namespace TensorFlowNET.UnitTest } } } + + [TestMethod] + public void Autocast_Case0() + { + var sess = tf.Session().as_default(); + ITensorOrOperation operation = tf.global_variables_initializer(); + // the cast to ITensorOrOperation is essential for the test of this method signature + var ret = sess.run(operation); + + ret.Should().BeNull(); + } [TestMethod] public void Autocast_Case1()