@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
namespace Tensorflow.NumPy | |||
{ | |||
public partial class NumPyImpl | |||
{ | |||
public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||
{ | |||
var dtype = NumPyUtils.GetResultType(a.dtype, np.float64); | |||
if(weights is null) | |||
{ | |||
var tensorA = math_ops.cast(a, dtype); | |||
var nd = math_ops.reduce_mean(tensorA, axis); | |||
return new NDArray(nd); | |||
} | |||
else | |||
{ | |||
var tensorW = math_ops.cast(weights, dtype); | |||
if(a.rank != weights.rank) | |||
{ | |||
var weights_sum = math_ops.reduce_sum(tensorW); | |||
var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } }); | |||
var avg = math_ops.tensordot(a, weights, axes) / weights_sum; | |||
} | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} | |||
} |
@@ -14,5 +14,9 @@ namespace Tensorflow.NumPy | |||
[AutoNumPy] | |||
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); | |||
[AutoNumPy] | |||
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) | |||
=> tf.numpy.average(a, axis: axis, weights: weights, returned: returned); | |||
} | |||
} |
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.Text; | |||
namespace Tensorflow.NumPy | |||
{ | |||
internal class NumPyUtils | |||
{ | |||
public static TF_DataType GetResultType(params TF_DataType[] dtypes) | |||
{ | |||
var resultDType = dtypes[0]; | |||
for(int i = 1; i < dtypes.Length; i++) | |||
{ | |||
if (dtypes[i].get_datatype_size() > resultDType.get_datatype_size()) | |||
resultDType = dtypes[i]; | |||
} | |||
return resultDType; | |||
} | |||
} | |||
} |
@@ -929,6 +929,72 @@ namespace Tensorflow | |||
throw new NotImplementedException("tensordot"); | |||
} | |||
public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null) | |||
{ | |||
Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false) | |||
{ | |||
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple)))) | |||
{ | |||
var shape_a = a.shape.dims; | |||
// axes | |||
int iter = 0; | |||
foreach (int i in axes) | |||
{ | |||
if (i >= 0) | |||
axes[0 + iter] = i; | |||
else | |||
axes[0 + iter] = i + len(shape_a); | |||
iter++; | |||
} | |||
// free | |||
int[] free = { }; | |||
iter = 0; | |||
foreach (int i in Enumerable.Range(0, len(axes))) | |||
if (!Array.Exists(axes, i => i == i)) | |||
free[free.Length] = i; | |||
// free_dims | |||
int[] free_dims = { }; | |||
foreach (int i in free) | |||
free_dims[free_dims.Length] = (int)shape_a[i]; | |||
int prod_free = (int)np.prod(free_dims); | |||
// prod_axes | |||
int[] prod_axes_pre = { }; | |||
foreach (int i in axes) | |||
prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i]; | |||
int prod_axes = (int)np.prod(prod_axes_pre); | |||
// perm | |||
Tensor perm; | |||
if (flipped) | |||
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free); | |||
else | |||
perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free) | |||
+ ops.convert_to_tensor(list(axes)); | |||
// new_shape | |||
Shape new_shape; | |||
if (flipped) | |||
new_shape = new Shape(new int[] { prod_axes, prod_free }); | |||
else | |||
new_shape = new Shape(new int[] { prod_free, prod_axes }); | |||
} | |||
throw new NotImplementedException("_tensordot_reshape"); | |||
} | |||
return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope => | |||
{ | |||
name = scope; | |||
var (a_axes, b_axes) = (axes[0], axes[1]); | |||
return x; | |||
}); | |||
} | |||
public static Tensor truediv(Tensor x, Tensor y, string name = null) | |||
=> _truediv_python3(x, y, name); | |||
@@ -78,7 +78,7 @@ namespace Tensorflow | |||
/// <summary> | |||
/// The name of the device on which this tensor will be produced, or null. | |||
/// </summary> | |||
public virtual string Device => op.Device; | |||
public virtual string Device => op?.Device; | |||
public long[] dims => shape.dims; | |||
/// <summary> | |||
@@ -0,0 +1,32 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow; | |||
using Tensorflow.NumPy; | |||
namespace TensorFlowNET.UnitTest.NumPy | |||
{ | |||
/// <summary> | |||
/// https://numpy.org/doc/stable/reference/routines.statistics.html | |||
/// </summary> | |||
[TestClass] | |||
public class StatisticsTest : EagerModeTestBase | |||
{ | |||
[TestMethod] | |||
public void average() | |||
{ | |||
var data = np.arange(1, 5); | |||
var avg = np.average(data); | |||
Assert.AreEqual(avg, 2.5); | |||
data = np.arange(6).reshape((3, 2)); | |||
avg = np.average(data, axis: 1); | |||
assertAllEqual(avg.ToArray<double>(), new[] { 0.5, 2.5, 4.5 }); | |||
// avg = np.average(data, axis: 1, weights: new[] { 1.0 / 4, 3.0 / 4 }); | |||
// assertAllEqual(avg.ToArray<double>(), new[] { 0.75, 2.75, 4.75 }); | |||
} | |||
} | |||
} |