From 387ae4c35600a35710e4a43938558ba29abf5911 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 17 Nov 2019 23:40:57 -0600 Subject: [PATCH] overload tf.reduce_mean --- src/TensorFlowNET.Core/APIs/tf.math.cs | 2 +- .../Operations/array_ops.cs | 10 +++++++++ src/TensorFlowNET.Core/Operations/math_ops.cs | 22 +++++++++++++++---- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 6cb43980..790e391e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -474,7 +474,7 @@ namespace Tensorflow public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); - public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) + public Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name); public Tensor round(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 04964069..c487f478 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -228,6 +228,16 @@ namespace Tensorflow public static Tensor rank(Tensor input, string name = null) => rank_internal(input, name, optimize: true); + public static Tensor rank(Tensor[] inputs, string name = null) + { + return tf_with(ops.name_scope(name, "Rank", new { inputs }), scope => + { + name = scope; + var input_tensor = ops.convert_to_tensor(inputs); + return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name); + }); + } + public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) { return tf_with(ops.name_scope(name, "Rank", new List { input }), scope => diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 848a89cd..bb8d7134 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -219,10 +219,19 @@ namespace Tensorflow } } - public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) + public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) { - var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); + if(axis == null) + { + var r = _ReductionDims(input_tensors, axis); + var m = gen_math_ops.mean(input_tensors, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + else + { + var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); + } } /// @@ -492,7 +501,7 @@ namespace Tensorflow return output; } - private static Tensor _may_reduce_to_scalar(bool keepdims, int axis, Tensor output) + private static Tensor _may_reduce_to_scalar(bool keepdims, int? axis, Tensor output) { return output; } @@ -515,6 +524,11 @@ namespace Tensorflow return axis; } + private static Tensor _ReductionDims(Tensor[] x, int? axis = null, string name = null) + { + return range(0, array_ops.rank(x)); + } + private static Tensor _ReductionDims(Tensor x, int[] axis) { if (axis != null)