|
@@ -368,7 +368,7 @@ namespace Tensorflow |
|
|
if (y.dtype.is_complex()) |
|
|
if (y.dtype.is_complex()) |
|
|
throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); |
|
|
throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); |
|
|
var shape = array_ops.shape(y); |
|
|
var shape = array_ops.shape(y); |
|
|
var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); |
|
|
|
|
|
|
|
|
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); |
|
|
var fill = gen_array_ops.fill(shape, constant); |
|
|
var fill = gen_array_ops.fill(shape, constant); |
|
|
new_grad_ys.Add(fill); |
|
|
new_grad_ys.Add(fill); |
|
|
} |
|
|
} |
|
|