Browse Source

np.average #894

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
a22e92de50
6 changed files with 153 additions and 1 deletions
  1. +31
    -0
      src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs
  2. +4
    -0
      src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs
  3. +19
    -0
      src/TensorFlowNET.Core/NumPy/NumPyUtils.cs
  4. +66
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +32
    -0
      test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs

+ 31
- 0
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;

namespace Tensorflow.NumPy
{
public partial class NumPyImpl
{
public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false)
{
var dtype = NumPyUtils.GetResultType(a.dtype, np.float64);
if(weights is null)
{
var tensorA = math_ops.cast(a, dtype);
var nd = math_ops.reduce_mean(tensorA, axis);
return new NDArray(nd);
}
else
{
var tensorW = math_ops.cast(weights, dtype);
if(a.rank != weights.rank)
{
var weights_sum = math_ops.reduce_sum(tensorW);
var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } });
var avg = math_ops.tensordot(a, weights, axes) / weights_sum;
}
throw new NotImplementedException("");
}
}
}
}

+ 4
- 0
src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs View File

@@ -14,5 +14,9 @@ namespace Tensorflow.NumPy

[AutoNumPy]
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis));

[AutoNumPy]
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false)
=> tf.numpy.average(a, axis: axis, weights: weights, returned: returned);
}
}

+ 19
- 0
src/TensorFlowNET.Core/NumPy/NumPyUtils.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Text;

namespace Tensorflow.NumPy
{
internal class NumPyUtils
{
public static TF_DataType GetResultType(params TF_DataType[] dtypes)
{
var resultDType = dtypes[0];
for(int i = 1; i < dtypes.Length; i++)
{
if (dtypes[i].get_datatype_size() > resultDType.get_datatype_size())
resultDType = dtypes[i];
}
return resultDType;
}
}
}

+ 66
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -929,6 +929,72 @@ namespace Tensorflow
throw new NotImplementedException("tensordot");
}

public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null)
{
Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
{
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
{
var shape_a = a.shape.dims;

// axes
int iter = 0;
foreach (int i in axes)
{
if (i >= 0)
axes[0 + iter] = i;
else
axes[0 + iter] = i + len(shape_a);
iter++;
}

// free
int[] free = { };
iter = 0;
foreach (int i in Enumerable.Range(0, len(axes)))
if (!Array.Exists(axes, i => i == i))
free[free.Length] = i;

// free_dims
int[] free_dims = { };
foreach (int i in free)
free_dims[free_dims.Length] = (int)shape_a[i];

int prod_free = (int)np.prod(free_dims);

// prod_axes
int[] prod_axes_pre = { };
foreach (int i in axes)
prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i];
int prod_axes = (int)np.prod(prod_axes_pre);

// perm
Tensor perm;
if (flipped)
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
else
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
+ ops.convert_to_tensor(list(axes));

// new_shape
Shape new_shape;
if (flipped)
new_shape = new Shape(new int[] { prod_axes, prod_free });
else
new_shape = new Shape(new int[] { prod_free, prod_axes });
}

throw new NotImplementedException("_tensordot_reshape");
}

return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope =>
{
name = scope;
var (a_axes, b_axes) = (axes[0], axes[1]);
return x;
});
}

public static Tensor truediv(Tensor x, Tensor y, string name = null)
=> _truediv_python3(x, y, name);



+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow
/// <summary>
/// The name of the device on which this tensor will be produced, or null.
/// </summary>
public virtual string Device => op.Device;
public virtual string Device => op?.Device;
public long[] dims => shape.dims;

/// <summary>


+ 32
- 0
test/TensorFlowNET.UnitTest/NumPy/Statistics.Test.cs View File

@@ -0,0 +1,32 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;

namespace TensorFlowNET.UnitTest.NumPy
{
/// <summary>
/// https://numpy.org/doc/stable/reference/routines.statistics.html
/// </summary>
[TestClass]
public class StatisticsTest : EagerModeTestBase
{
[TestMethod]
public void average()
{
var data = np.arange(1, 5);
var avg = np.average(data);
Assert.AreEqual(avg, 2.5);

data = np.arange(6).reshape((3, 2));
avg = np.average(data, axis: 1);
assertAllEqual(avg.ToArray<double>(), new[] { 0.5, 2.5, 4.5 });

// avg = np.average(data, axis: 1, weights: new[] { 1.0 / 4, 3.0 / 4 });
// assertAllEqual(avg.ToArray<double>(), new[] { 0.75, 2.75, 4.75 });
}
}
}

Loading…
Cancel
Save