diff --git a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs index 2396ae25..6999bee7 100644 --- a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs +++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs @@ -22,10 +22,10 @@ namespace Tensorflow.Gradients public long GetID() => id; - public Tensor ZerosLike(int[] shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) - => tf.zeros(shape == null ? new int[0] : shape, dtype: dtype); + public Tensor ZerosLike() + => tf.zeros(shape: shape, dtype: dtype); - public Tensor OnesLike(int[] shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) - => tf.ones(shape == null ? new int[0] : shape, dtype: dtype); + public Tensor OnesLike() + => tf.ones(shape: shape, dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index c40efc33..0d6afbc9 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -609,7 +609,7 @@ namespace Tensorflow.Gradients return tf_with(ops.control_dependencies(grads), delegate { x = math_ops.conj(x); - var y = constant_op.constant(2.0f, dtype: x.dtype); + var y = constant_op.constant(2.0, dtype: x.dtype); return new Tensor[] { math_ops.multiply(grad, math_ops.multiply(x, y)) }; }); } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 369e0d26..0149ae1b 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -119,6 +119,9 @@ namespace Tensorflow { switch (dtype) { + case TF_DataType.TF_DOUBLE: + value = Convert.ToDouble(value); + break; case TF_DataType.TF_FLOAT: value = Convert.ToSingle(value); break; diff --git a/test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs b/test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs index 0f91f93f..e34fce17 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs @@ -13,7 +13,7 @@ namespace Tensorflow.UnitTest.TF_API [TestMethod] public void GradientFloatTest() { - var x = tf.Variable(3.0, dtype: TF_DataType.TF_FLOAT); + var x = tf.Variable(3.0, dtype: tf.float32); using var tape = tf.GradientTape(); var y = tf.square(x); var y_grad = tape.gradient(y, x); @@ -22,26 +22,22 @@ namespace Tensorflow.UnitTest.TF_API [TestMethod] public void GradientDefaultTest() - {//error 1#: Variable default type + { var x = tf.Variable(3.0); using var tape = tf.GradientTape(); var y = tf.square(x); var y_grad = tape.gradient(y, x); Assert.AreEqual(9.0, (double)y); } + [TestMethod] public void GradientDoubleTest() - {//error 2#: Variable double type - var x = tf.Variable(3.0, dtype: TF_DataType.TF_DOUBLE); + { + var x = tf.Variable(3.0, dtype: tf.float64); using var tape = tf.GradientTape(); var y = tf.square(x); var y_grad = tape.gradient(y, x); Assert.AreEqual(9.0, (double)y); } - - - - - } }