diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index 956c52be..5b79d138 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -57,6 +57,9 @@ namespace Tensorflow public Tensor lstsq(Tensor matrix, Tensor rhs, NDArray l2_regularizer = null, bool fast = true, string name = null) => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); + + public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null) + => math_ops.tensordot(x, y, axes, name: name); } public Tensor diag(Tensor diagonal, string name = null) diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs index 990c2ad6..bc6047eb 100644 --- a/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs +++ b/src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs @@ -20,7 +20,7 @@ namespace Tensorflow.NumPy if(a.rank != weights.rank) { var weights_sum = math_ops.reduce_sum(tensorW); - var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } }); + var axes = np.array(new[,] { { axis }, { 0 } }); var avg = math_ops.tensordot(a, weights, axes) / weights_sum; } diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index aea4e678..832a6658 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -104,6 +104,14 @@ namespace Tensorflow.NumPy if (shape.ndim != shape3.Length) return false; return Enumerable.SequenceEqual(shape.as_int_list(), shape3); + case List shape4: + if (shape.ndim != shape4.Count) + return false; + return Enumerable.SequenceEqual(shape.dims, shape4); + case List shape5: + if (shape.ndim != shape5.Count) + return false; + return Enumerable.SequenceEqual(shape.as_int_list(), shape5); default: return false; } diff --git 
public static Tensor tanh(Tensor x, string name = null)
    => gen_math_ops.tanh(x, name);

/// <summary>
/// Contracts <paramref name="a"/> with <paramref name="b"/> over the axes described by
/// <paramref name="axes"/>.
/// </summary>
/// <remarks>
/// Both operands are flattened to rank-2 matrices by <c>_tensordot_reshape</c>, multiplied with
/// <c>matmul</c>, and the product is reshaped to the concatenation of the free (non-contracted)
/// dimensions of <paramref name="a"/> followed by those of <paramref name="b"/>.
/// </remarks>
/// <param name="a">Left operand.</param>
/// <param name="b">Right operand.</param>
/// <param name="axes">
/// Either a scalar <c>N</c> (contract the last <c>N</c> axes of <paramref name="a"/> with the first
/// <c>N</c> axes of <paramref name="b"/>), a two-element vector <c>[a_axis, b_axis]</c>, or a
/// <c>2 x k</c> matrix whose rows list the axes of <paramref name="a"/> and <paramref name="b"/>.
/// </param>
/// <param name="name">Optional name for the operation scope.</param>
/// <returns>The contracted tensor.</returns>
/// <exception cref="ValueError">Thrown when <paramref name="axes"/> is malformed.</exception>
/// <exception cref="NotImplementedException">Thrown when an operand's shape is not fully defined.</exception>
public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
{
    return tf_with(ops.name_scope(name, "Tensordot", new { a, b, axes }), scope =>
    {
        name = scope;
        var (a_axes, b_axes) = _tensordot_axes(a, axes);
        var (a_reshape, a_free_dims, _) = _tensordot_reshape(a, a_axes);
        // `b` is flipped so that its contracted axes lead and line up with the trailing axes of `a`.
        var (b_reshape, b_free_dims, _) = _tensordot_reshape(b, b_axes, true);
        var ab_matmul = matmul(a_reshape, b_reshape);

        var dims = new List<int>(a_free_dims.Length + b_free_dims.Length);
        dims.AddRange(a_free_dims);
        dims.AddRange(b_free_dims);

        // Skip the reshape when the matrix product already has the target shape.
        if (ab_matmul.shape.Equals(dims))
            return ab_matmul;

        return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
    });
}

/// <summary>
/// Translates the user-facing <paramref name="axes"/> argument into explicit axis lists for the
/// two operands of <c>tensordot</c>.
/// </summary>
/// <param name="a">Left operand; only its rank is consulted (scalar <paramref name="axes"/> case).</param>
/// <param name="axes">Scalar, two-element vector, or <c>2 x k</c> matrix of axes.</param>
/// <returns>The axes of the left operand and the axes of the right operand to contract.</returns>
/// <exception cref="ValueError">
/// Thrown when a scalar <paramref name="axes"/> is negative or exceeds the rank of
/// <paramref name="a"/>, when the leading dimension of <paramref name="axes"/> is not 2, or when
/// <paramref name="axes"/> has rank greater than 2.
/// </exception>
static (int[], int[]) _tensordot_axes(Tensor a, NDArray axes)
{
    if (axes.rank == 0)
    {
        int axe = axes;
        if (axe < 0)
            throw new ValueError($"`axes` must be at least 0. Received: {axe}.");
        if (axe > a.shape.ndim)
            throw new ValueError("`axes` must not be larger than the number of " +
                $"dimensions of tensor {a}. Received {axes}, vs " +
                $"tensor dimensions {a.ndim}.");
        return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
            Binding.range(0, axe).ToArray());
    }

    // Non-scalar specs always carry one entry (or row) per operand.
    if (axes.shape.dims[0] != 2)
        throw new ValueError("`axes` must be an integer or have length 2. " +
            $"Received {axes}.");

    if (axes.rank == 1)
    {
        (int a_axe, int b_axe) = (axes[0], axes[1]);
        return (new[] { a_axe }, new[] { b_axe });
    }

    if (axes.rank == 2)
    {
        // Row 0 lists the axes of `a`, row 1 the axes of `b`; read every column
        // instead of only the first one so multi-axis contractions are honoured.
        int count = (int)axes.shape.dims[1];
        NDArray a_row = axes[0];
        NDArray b_row = axes[1];
        var a_axes = new int[count];
        var b_axes = new int[count];
        for (int j = 0; j < count; j++)
        {
            a_axes[j] = a_row[j];
            b_axes[j] = b_row[j];
        }
        return (a_axes, b_axes);
    }

    throw new ValueError("`axes` must be a scalar, a vector of length 2 or a 2 x k matrix. " +
        $"Received {axes}.");
}

/// <summary>
/// Transposes and flattens <paramref name="a"/> into a rank-2 matrix suitable for <c>matmul</c>.
/// </summary>
/// <param name="a">Tensor to reshape; its shape must be fully defined.</param>
/// <param name="axes">Axes of <paramref name="a"/> to contract; negative values count from the end.</param>
/// <param name="flipped">
/// When <c>false</c> the result is <c>[prod(free), prod(axes)]</c>; when <c>true</c> it is
/// <c>[prod(axes), prod(free)]</c>.
/// </param>
/// <returns>
/// The reshaped tensor followed by the sizes of the free dimensions (returned twice; the static
/// and dynamic free dimensions coincide because the shape is fully defined).
/// </returns>
/// <exception cref="NotImplementedException">Thrown when the shape of <paramref name="a"/> is not fully defined.</exception>
static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
{
    if (!a.shape.IsFullyDefined)
        throw new NotImplementedException("_tensordot_reshape");

    var shape_a = a.shape.as_int_list();
    int rank = a.shape.ndim;

    // Normalise negative axes to their non-negative equivalents.
    int[] contracted = axes.Select(i => i >= 0 ? i : i + len(shape_a)).ToArray();

    // Axes that survive the contraction, in their original order.
    int[] free = Binding.range(rank).Where(i => !contracted.Contains(i)).ToArray();
    int[] free_dims = free.Select(i => shape_a[i]).ToArray();

    int prod_free = np.prod(free_dims);
    int prod_axes = np.prod(contracted.Select(i => shape_a[i]).ToArray());

    // Contracted axes go last for the left operand and first for the (flipped) right operand.
    int[] perm = (flipped ? contracted.Concat(free) : free.Concat(contracted)).ToArray();

    Shape new_shape = flipped
        ? new Shape(new int[] { prod_axes, prod_free })
        : new Shape(new int[] { prod_free, prod_axes });

    // The permutation must actually be applied before flattening; reshaping the
    // un-transposed tensor mixes contracted and free elements whenever `perm`
    // is not the identity.
    var a_trans = perm.SequenceEqual(Enumerable.Range(0, rank))
        ? a
        : array_ops.transpose(a, perm);

    var reshaped_a = array_ops.reshape(a_trans, new_shape);
    return (reshaped_a, free_dims, free_dims);
}

public static Tensor truediv(Tensor x, Tensor y, string name = null) => _truediv_python3(x, y, name); diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs index f7fb965b..45448cbb 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -63,5 +63,19 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var norm = tf.linalg.global_norm(t_list); Assert.AreEqual(norm.numpy(), 14.282857f); } + + [TestMethod] + public void Tensordot() + { + var a = tf.constant(new[] { 1, 2 }); + var b = tf.constant(new[] { 2, 3 }); + var c = tf.linalg.tensordot(a, b, 0); + Assert.AreEqual(c.shape, (2, 2)); + AssetSequenceEqual(c.ToArray(), new[] { 2, 3, 4, 6 }); + + c = tf.linalg.tensordot(a, b, new[] { 0, 0 }); + Assert.AreEqual(c.shape.ndim, 0); + Assert.AreEqual(c.numpy(), 8); + } } }