From ee2bbbc1019c27c67b0cef410672a7180c82ab46 Mon Sep 17 00:00:00 2001 From: degtiadr Date: Thu, 30 May 2019 22:13:38 +0200 Subject: [PATCH] Inconsistency in handling of DT_FLOAT and DT_DOUBLE types by gradient calculation --- src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 3c6dac91..0d1e6c8a 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -368,7 +368,7 @@ namespace Tensorflow if (y.dtype.is_complex()) throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); var shape = array_ops.shape(y); - var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); + var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); var fill = gen_array_ops.fill(shape, constant); new_grad_ys.Add(fill); }