Browse Source

GradientConcatTest

tags/v0.60-tf.numpy
MPnoy Haiping 4 years ago
parent
commit
cd611f1750
1 changed files with 16 additions and 1 deletions
  1. +16
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

+ 16
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.ManagedAPI namespace TensorFlowNET.UnitTest.ManagedAPI
@@ -50,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestMethod] [TestMethod]
public void GradientSliceTest() public void GradientSliceTest()
{ {
var X = tf.zeros(new Tensorflow.TensorShape(10));
var X = tf.zeros(new TensorShape(10));
var W = tf.Variable(-0.06f, name: "weight"); var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias"); var b = tf.Variable(-0.73f, name: "bias");
using var g = tf.GradientTape(); using var g = tf.GradientTape();
@@ -60,5 +61,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreNotEqual(gradients.Item1, null); Assert.AreNotEqual(gradients.Item1, null);
Assert.AreNotEqual(gradients.Item2, null); Assert.AreNotEqual(gradients.Item2, null);
} }

[TestMethod]
public void GradientConcatTest()
{
var X = tf.zeros(new TensorShape(10));
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
var test = tf.concat(new Tensor[] { W, b }, 0);
using var g = tf.GradientTape();
var pred = test[0] * X + test[1];
var gradients = g.gradient(pred, (W, b));
Assert.AreEqual((float)gradients.Item1, 0);
Assert.AreEqual((float)gradients.Item2, 10);
}
} }
} }

Loading…
Cancel
Save