|
@@ -51,29 +51,29 @@ namespace TensorFlowNET.UnitTest.ManagedAPI |
|
|
[TestMethod] |
|
|
[TestMethod] |
|
|
public void GradientSliceTest() |
|
|
public void GradientSliceTest() |
|
|
{ |
|
|
{ |
|
|
var X = tf.zeros(new TensorShape(10)); |
|
|
|
|
|
|
|
|
var X = tf.zeros(10); |
|
|
var W = tf.Variable(-0.06f, name: "weight"); |
|
|
var W = tf.Variable(-0.06f, name: "weight"); |
|
|
var b = tf.Variable(-0.73f, name: "bias"); |
|
|
var b = tf.Variable(-0.73f, name: "bias"); |
|
|
using var g = tf.GradientTape(); |
|
|
using var g = tf.GradientTape(); |
|
|
var pred = W * X + b; |
|
|
var pred = W * X + b; |
|
|
var test = tf.slice(pred, new[] { 0 }, pred.shape); |
|
|
var test = tf.slice(pred, new[] { 0 }, pred.shape); |
|
|
var gradients = g.gradient(test, (W, b)); |
|
|
var gradients = g.gradient(test, (W, b)); |
|
|
Assert.AreNotEqual(gradients.Item1, null); |
|
|
|
|
|
Assert.AreNotEqual(gradients.Item2, null); |
|
|
|
|
|
|
|
|
Assert.AreEqual((float)gradients.Item1, 0f); |
|
|
|
|
|
Assert.AreEqual((float)gradients.Item2, 10f); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
[TestMethod] |
|
|
[TestMethod] |
|
|
public void GradientConcatTest() |
|
|
public void GradientConcatTest() |
|
|
{ |
|
|
{ |
|
|
var X = tf.zeros(new TensorShape(10)); |
|
|
|
|
|
|
|
|
var X = tf.zeros(10); |
|
|
var W = tf.Variable(-0.06f, name: "weight"); |
|
|
var W = tf.Variable(-0.06f, name: "weight"); |
|
|
var b = tf.Variable(-0.73f, name: "bias"); |
|
|
var b = tf.Variable(-0.73f, name: "bias"); |
|
|
var test = tf.concat(new Tensor[] { W, b }, 0); |
|
|
var test = tf.concat(new Tensor[] { W, b }, 0); |
|
|
using var g = tf.GradientTape(); |
|
|
using var g = tf.GradientTape(); |
|
|
var pred = test[0] * X + test[1]; |
|
|
var pred = test[0] * X + test[1]; |
|
|
var gradients = g.gradient(pred, (W, b)); |
|
|
var gradients = g.gradient(pred, (W, b)); |
|
|
Assert.AreEqual((float)gradients.Item1, 0); |
|
|
|
|
|
Assert.AreEqual((float)gradients.Item2, 10); |
|
|
|
|
|
|
|
|
Assert.IsNull(gradients.Item1); |
|
|
|
|
|
Assert.IsNull(gradients.Item2); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |