Browse Source

Inconsistency in handling of DT_FLOAT and DT_DOUBLE types by gradient calculation

tags/v0.9
degtiadr 6 years ago
parent
commit
ee2bbbc101
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -368,7 +368,7 @@ namespace Tensorflow
if (y.dtype.is_complex()) if (y.dtype.is_complex())
throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})");
var shape = array_ops.shape(y); var shape = array_ops.shape(y);
var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}");
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}");
var fill = gen_array_ops.fill(shape, constant); var fill = gen_array_ops.fill(shape, constant);
new_grad_ys.Add(fill); new_grad_ys.Add(fill);
} }


Loading…
Cancel
Save