Browse Source

Fix data type in TapeTensor.OnesLike. #581

tags/v0.20
Oceania2018 5 years ago
parent
commit
a08275563e
4 changed files with 13 additions and 14 deletions
  1. +4
    -4
      src/TensorFlowNET.Core/Gradients/TapeTensor.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  4. +5
    -9
      test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs

+ 4
- 4
src/TensorFlowNET.Core/Gradients/TapeTensor.cs View File

@@ -22,10 +22,10 @@ namespace Tensorflow.Gradients

public long GetID() => id;

public Tensor ZerosLike(int[] shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
=> tf.zeros(shape == null ? new int[0] : shape, dtype: dtype);
public Tensor ZerosLike()
=> tf.zeros(shape: shape, dtype: dtype);

public Tensor OnesLike(int[] shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
=> tf.ones(shape == null ? new int[0] : shape, dtype: dtype);
public Tensor OnesLike()
=> tf.ones(shape: shape, dtype: dtype);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -609,7 +609,7 @@ namespace Tensorflow.Gradients
return tf_with(ops.control_dependencies(grads), delegate
{
x = math_ops.conj(x);
var y = constant_op.constant(2.0f, dtype: x.dtype);
var y = constant_op.constant(2.0, dtype: x.dtype);
return new Tensor[] { math_ops.multiply(grad, math_ops.multiply(x, y)) };
});
}


+ 3
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -119,6 +119,9 @@ namespace Tensorflow
{
switch (dtype)
{
case TF_DataType.TF_DOUBLE:
value = Convert.ToDouble(value);
break;
case TF_DataType.TF_FLOAT:
value = Convert.ToSingle(value);
break;


+ 5
- 9
test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.UnitTest.TF_API
[TestMethod]
public void GradientFloatTest()
{
var x = tf.Variable(3.0, dtype: TF_DataType.TF_FLOAT);
var x = tf.Variable(3.0, dtype: tf.float32);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
@@ -22,26 +22,22 @@ namespace Tensorflow.UnitTest.TF_API

[TestMethod]
public void GradientDefaultTest()
{//error 1#: Variable default type
{
var x = tf.Variable(3.0);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
Assert.AreEqual(9.0, (double)y);
}

[TestMethod]
public void GradientDoubleTest()
{//error 2#: Variable double type
var x = tf.Variable(3.0, dtype: TF_DataType.TF_DOUBLE);
{
var x = tf.Variable(3.0, dtype: tf.float64);
using var tape = tf.GradientTape();
var y = tf.square(x);
var y_grad = tape.gradient(y, x);
Assert.AreEqual(9.0, (double)y);
}





}
}

Loading…
Cancel
Save