|
|
@@ -14,8 +14,19 @@ namespace TensorFlowNET.UnitTest |
|
|
|
{ |
|
|
|
var a = tf.constant(0.0); |
|
|
|
var b = 2.0 * a; |
|
|
|
var c = a + b; |
|
|
|
var g = tf.gradients(c, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b }); |
|
|
|
Assert.AreEqual(b.name, "mul:0"); |
|
|
|
Assert.AreEqual(b.op.inputs[0].name, "mul/x:0"); |
|
|
|
Assert.AreEqual(b.op.inputs[1].name, "Const:0"); |
|
|
|
|
|
|
|
var ys = a + b; |
|
|
|
Assert.AreEqual(ys.name, "add:0"); |
|
|
|
Assert.AreEqual(ys.op.inputs[0].name, "Const:0"); |
|
|
|
Assert.AreEqual(ys.op.inputs[1].name, "mul:0"); |
|
|
|
|
|
|
|
var xs = new Tensor[] { a, b }; |
|
|
|
var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b }); |
|
|
|
Assert.AreEqual(g[0].name, "gradients/Fill:0"); |
|
|
|
Assert.AreEqual(g[1].name, "gradients/Fill:0"); |
|
|
|
} |
|
|
|
} |
|
|
|
} |