From a22e92de50e9aa6c68b9a07fc01e00a012efe333 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 25 Dec 2021 09:28:10 -0600 Subject: [PATCH] np.average #894 --- .../Implementation/NumPyImpl.Statistics.cs | 31 +++++++++ .../NumPy/NumPy.Statistics.cs | 4 ++ src/TensorFlowNET.Core/NumPy/NumPyUtils.cs | 19 ++++++ src/TensorFlowNET.Core/Operations/math_ops.cs | 66 +++++++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- .../NumPy/Statistics.Test.cs | 32 +++++++++ 6 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs create mode 100644 src/TensorFlowNET.Core/NumPy/NumPyUtils.cs create mode 100644 test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs new file mode 100644 index 00000000..990c2ad6 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; + +namespace Tensorflow.NumPy +{ + public partial class NumPyImpl + { + public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) + { + var dtype = NumPyUtils.GetResultType(a.dtype, np.float64); + if(weights is null) + { + var tensorA = math_ops.cast(a, dtype); + var nd = math_ops.reduce_mean(tensorA, axis); + return new NDArray(nd); + } + else + { + var tensorW = math_ops.cast(weights, dtype); + if(a.rank != weights.rank) + { + var weights_sum = math_ops.reduce_sum(tensorW); + var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } }); + var avg = math_ops.tensordot(a, weights, axes) / weights_sum; + } + + throw new NotImplementedException(""); + } + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs index 806d38b2..5d86b1b3 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs @@ -14,5 +14,9 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); + + [AutoNumPy] + public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) + => tf.numpy.average(a, axis: axis, weights: weights, returned: returned); } } diff --git a/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs b/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs new file mode 100644 index 00000000..35356603 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NumPyUtils.cs @@ -0,0 +1,19 @@ +using System; +using System.Text; + +namespace Tensorflow.NumPy +{ + internal class NumPyUtils + { + public static TF_DataType GetResultType(params TF_DataType[] dtypes) + { + var resultDType = dtypes[0]; + for(int i = 1; i < dtypes.Length; i++) + { + if (dtypes[i].get_datatype_size() > resultDType.get_datatype_size()) + resultDType = dtypes[i]; + } + return resultDType; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 5657fafa..8a058ea4 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -929,6 +929,72 @@ namespace Tensorflow throw new NotImplementedException("tensordot"); } + public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null) + { + Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) + { + if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List), typeof(Tuple)))) + { + var shape_a = a.shape.dims; + + // axes + int iter = 0; + foreach (int i in axes) + { + if (i >= 0) + axes[0 + iter] = i; + else + axes[0 + iter] = i + len(shape_a); + iter++; + } + + // free + int[] free = { }; + iter = 0; + foreach (int i in Enumerable.Range(0, len(axes))) + if (!Array.Exists(axes, i => i == i)) + free[free.Length] = i; + + // free_dims + int[] free_dims = { }; + foreach (int i in free) + free_dims[free_dims.Length] = (int)shape_a[i]; + + int prod_free = (int)np.prod(free_dims); + + // prod_axes + int[] prod_axes_pre = { }; + foreach (int i in axes) + prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i]; + int prod_axes = (int)np.prod(prod_axes_pre); + + // perm + Tensor perm; + if (flipped) + perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free); + else + perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free) + + ops.convert_to_tensor(list(axes)); + + // new_shape + Shape new_shape; + if (flipped) + new_shape = new Shape(new int[] { prod_axes, prod_free }); + else + new_shape = new Shape(new int[] { prod_free, prod_axes }); + } + + throw new NotImplementedException("_tensordot_reshape"); + } + + return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope => + { + name = scope; + var (a_axes, b_axes) = (axes[0], axes[1]); + return x; + }); + } + public static Tensor truediv(Tensor x, Tensor y, string name = null) => _truediv_python3(x, y, name); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 19f91961..e9ab81a7 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -78,7 +78,7 @@ namespace Tensorflow /// /// The name of the device on which this tensor will be produced, or null. /// - public virtual string Device => op.Device; + public virtual string Device => op?.Device; public long[] dims => shape.dims; /// diff --git a/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs new file mode 100644 index 00000000..42005b15 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs @@ -0,0 +1,32 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/routines.statistics.html + /// + [TestClass] + public class StatisticsTest : EagerModeTestBase + { + [TestMethod] + public void average() + { + var data = np.arange(1, 5); + var avg = np.average(data); + Assert.AreEqual(avg, 2.5); + + data = np.arange(6).reshape((3, 2)); + avg = np.average(data, axis: 1); + assertAllEqual(avg.ToArray(), new[] { 0.5, 2.5, 4.5 }); + + // avg = np.average(data, axis: 1, weights: new[] { 1.0 / 4, 3.0 / 4 }); + // assertAllEqual(avg.ToArray(), new[] { 0.75, 2.75, 4.75 }); + } + } +}