Browse Source

overload tf.reduce_mean

tags/v0.13
Oceania2018 5 years ago
parent
commit
387ae4c356
3 changed files with 29 additions and 5 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +18
    -4
      src/TensorFlowNET.Core/Operations/math_ops.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -474,7 +474,7 @@ namespace Tensorflow
public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);

public Tensor round(Tensor x, string name = null)


+ 10
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -228,6 +228,16 @@ namespace Tensorflow
public static Tensor rank(Tensor input, string name = null)
=> rank_internal(input, name, optimize: true);

public static Tensor rank(Tensor[] inputs, string name = null)
{
return tf_with(ops.name_scope(name, "Rank", new { inputs }), scope =>
{
name = scope;
var input_tensor = ops.convert_to_tensor(inputs);
return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name);
});
}

public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>


+ 18
- 4
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -219,10 +219,19 @@ namespace Tensorflow
}
}

public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null)
{
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
if(axis == null)
{
var r = _ReductionDims(input_tensors, axis);
var m = gen_math_ops.mean(input_tensors, r, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
else
{
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}
}

/// <summary>
@@ -492,7 +501,7 @@ namespace Tensorflow
return output;
}

private static Tensor _may_reduce_to_scalar(bool keepdims, int axis, Tensor output)
private static Tensor _may_reduce_to_scalar(bool keepdims, int? axis, Tensor output)
{
return output;
}
@@ -515,6 +524,11 @@ namespace Tensorflow
return axis;
}

private static Tensor _ReductionDims(Tensor[] x, int? axis = null, string name = null)
{
return range(0, array_ops.rank(x));
}

private static Tensor _ReductionDims(Tensor x, int[] axis)
{
if (axis != null)


Loading…
Cancel
Save