diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index d2ab44b2..cd7a3642 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -29,6 +29,9 @@ namespace Tensorflow public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) => state_ops.assign(@ref, value, validate_shape, use_locking, name); + public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) + => state_ops.assign(@ref, value, validate_shape, use_locking, name); + public void device(string device_name) => get_default_graph().device(device_name); diff --git a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs index 6c593929..0b09f783 100644 --- a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs +++ b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs @@ -33,5 +33,22 @@ namespace TensorFlowNET.UnitTest.Basics } } } + + [TestMethod] + public void Bug397() + { + // fix bug https://github.com/SciSharp/TensorFlow.NET/issues/397 + var W = tf.Variable(-1, name: "weight_" + 1, dtype: tf.float32); + var init = tf.global_variables_initializer(); + var reluEval = tf.nn.relu(W); + var nonZero = tf.assign(W, reluEval); + + using (var sess = tf.Session()) + { + sess.run(init); + float result = nonZero.eval(); + Assert.IsTrue(result == 0f); + } + } } } \ No newline at end of file