Browse Source

Gradient Variable Test

tags/v0.20
pepure Haiping 5 years ago
parent
commit
a3672abdbe
1 changed files with 47 additions and 0 deletions
  1. +47
    -0
      test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs

+ 47
- 0
test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs View File

@@ -0,0 +1,47 @@
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace Tensorflow.UnitTest.TF_API
{
[TestClass]
public class GradientTest
{
[TestMethod]
public void GradientFloatTest()
{
var x = tf.Variable(3.0, dtype: TF_DataType.TF_FLOAT);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
Assert.AreEqual(9.0f, (float)y);
}

[TestMethod]
public void GradientDefaultTest()
{//error 1#: Variable default type
var x = tf.Variable(3.0);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
Assert.AreEqual(9.0, (double)y);
}
[TestMethod]
public void GradientDoubleTest()
{//error 2#: Variable double type
var x = tf.Variable(3.0, dtype: TF_DataType.TF_DOUBLE);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
Assert.AreEqual(9.0, (double)y);
}





}
}

Loading…
Cancel
Save