Browse Source

GradientConcatTest

tags/v0.60-tf.numpy
MPnoy Haiping 4 years ago
parent
commit
bf9c1adac7
1 changed files with 10 additions and 8 deletions
  1. +10
    -8
      test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

+ 10
- 8
test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

@@ -65,15 +66,16 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestMethod]
public void GradientConcatTest()
{
var X = tf.zeros(10);
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
var test = tf.concat(new Tensor[] { W, b }, 0);
var w1 = tf.Variable(new[] { new[] { 1f } });
var w2 = tf.Variable(new[] { new[] { 3f } });
using var g = tf.GradientTape();
var pred = test[0] * X + test[1];
var gradients = g.gradient(pred, (W, b));
Assert.IsNull(gradients.Item1);
Assert.IsNull(gradients.Item2);
var w = tf.concat(new Tensor[] { w1, w2 }, 0);
var x = tf.ones((1, 2));
var y = tf.reduce_sum(x, 1);
var r = tf.matmul(w, x);
var gradients = g.gradient(r, w);
Assert.AreEqual((float)gradients[0][0], 2f);
Assert.AreEqual((float)gradients[1][0], 2f);
}
}
}

Loading…
Cancel
Save