|
|
@@ -222,14 +222,14 @@ namespace Tensorflow |
|
|
|
/// <returns> The reduced tensor.</returns> |
|
|
|
public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) |
|
|
|
{ |
|
|
|
with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => |
|
|
|
return with(ops.name_scope(name, "ReduceLogSumExp", new { input_tensor }), scope => |
|
|
|
{ |
|
|
|
var raw_max = reduce_max(input_tensor, axis, true); |
|
|
|
var my_max = array_ops.stop_gradient(array_ops.where(gen_math_ops.is_finite(raw_max), raw_max, array_ops.zeros_like(raw_max))); |
|
|
|
var result = gen_math_ops.log( |
|
|
|
reduce_sum( |
|
|
|
gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)), |
|
|
|
new Tensor(axis), |
|
|
|
axis[0], |
|
|
|
keepdims)); |
|
|
|
if (!keepdims) |
|
|
|
{ |
|
|
@@ -238,7 +238,6 @@ namespace Tensorflow |
|
|
|
result = gen_math_ops.add(result, my_max); |
|
|
|
return _may_reduce_to_scalar(keepdims, axis, result); |
|
|
|
}); |
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) |
|
|
@@ -295,13 +294,17 @@ namespace Tensorflow |
|
|
|
if (!common_shapes.has_fully_defined_shape(output) && |
|
|
|
!keepdims && |
|
|
|
axis == null) |
|
|
|
// We want set_shape to be reflected in the C API graph for when we run it. |
|
|
|
output.shape = new long[0]; |
|
|
|
return output; |
|
|
|
} |
|
|
|
|
|
|
|
private static Tensor _may_reduce_to_scalar(bool keepdims, int[] axis, Tensor output) |
|
|
|
{ |
|
|
|
output.shape = new long[0]; |
|
|
|
if (!common_shapes.has_fully_defined_shape(output) && |
|
|
|
!keepdims && |
|
|
|
axis == null) |
|
|
|
output.shape = new long[0]; |
|
|
|
return output; |
|
|
|
} |
|
|
|
|
|
|
@@ -323,7 +326,7 @@ namespace Tensorflow |
|
|
|
if (axis != null) |
|
|
|
{ |
|
|
|
// should return axis. or check before. |
|
|
|
return null; |
|
|
|
return ops.convert_to_tensor(axis, TF_DataType.TF_INT32); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|