Browse Source

fix bug #397

tags/v0.12
Oceania2018 6 years ago
parent
commit
acf189fbfb
2 changed files with 20 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.ops.cs
  2. +17
    -0
      test/TensorFlowNET.UnitTest/Basics/AssignTests.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.ops.cs View File

@@ -29,6 +29,9 @@ namespace Tensorflow
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name); => state_ops.assign(@ref, value, validate_shape, use_locking, name);


public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);

public void device(string device_name) public void device(string device_name)
=> get_default_graph().device(device_name); => get_default_graph().device(device_name);




+ 17
- 0
test/TensorFlowNET.UnitTest/Basics/AssignTests.cs View File

@@ -33,5 +33,22 @@ namespace TensorFlowNET.UnitTest.Basics
} }
} }
} }

[TestMethod]
public void Bug397()
{
// fix bug https://github.com/SciSharp/TensorFlow.NET/issues/397
var W = tf.Variable(-1, name: "weight_" + 1, dtype: tf.float32);
var init = tf.global_variables_initializer();
var reluEval = tf.nn.relu(W);
var nonZero = tf.assign(W, reluEval);

using (var sess = tf.Session())
{
sess.run(init);
float result = nonZero.eval();
Assert.IsTrue(result == 0f);
}
}
} }
} }

Loading…
Cancel
Save