diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 3c6dac91..0d1e6c8a 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -368,7 +368,7 @@ namespace Tensorflow if (y.dtype.is_complex()) throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); var shape = array_ops.shape(y); - var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); + var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); var fill = gen_array_ops.fill(shape, constant); new_grad_ys.Add(fill); }