diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 8574b838..be614294 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -152,7 +152,7 @@ namespace Tensorflow /// /// /// - public Tensor transpose(T1 a, Shape perm = null, string name = "transpose", bool conjugate = false) + public Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false) => array_ops.transpose(a, perm, name, conjugate); /// diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index 7d4e418a..2b6051e0 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************/ +using Tensorflow.NumPy; using static Tensorflow.Binding; namespace Tensorflow @@ -40,13 +41,20 @@ namespace Tensorflow public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); + + public Tensor inv(Tensor input, bool adjoint = false, string name = null) + => ops.matrix_inverse(input, adjoint: adjoint, name: name); + + public Tensor lstsq(Tensor matrix, Tensor rhs, + NDArray l2_regularizer = null, bool fast = true, string name = null) + => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); } public Tensor diag(Tensor diagonal, string name = null) => gen_array_ops.diag(diagonal, name: name); - public Tensor matmul(Tensor a, Tensor b) - => math_ops.matmul(a, b); + public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) + => math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b); /// /// Multiply slices of the two matrices "x" and "y". diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 45f05ed7..4c7b6488 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -50,6 +50,9 @@ namespace Tensorflow public static implicit operator Tensor(Axis axis) => constant_op.constant(axis); + + public override string ToString() + => $"({string.Join(", ", axis)})"; } } diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs new file mode 100644 index 00000000..92ef6b69 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class LinearAlgebraImpl + { + public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn") + { + return a; + } + } +} diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index 515c3dcb..3b5e028a 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -48,7 +48,7 @@ namespace Tensorflow.NumPy => new NDArray(value); public static implicit operator Tensor(NDArray nd) - => nd._tensor; + => nd?._tensor; public static implicit operator NDArray(Tensor tensor) => new NDArray(tensor); diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs index 85cbeb71..7131b425 100644 --- a/src/TensorFlowNET.Core/Numpy/Numpy.cs +++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs @@ -105,5 +105,7 @@ namespace Tensorflow.NumPy { throw new NotImplementedException(""); } + + public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); } } diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index a1068215..263550e3 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -38,6 +38,16 @@ namespace Tensorflow } } + #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges + public int Length => ndim; + public long[] Slice(int start, int length) + { + var slice = new long[length]; + Array.Copy(_dims, start, slice, 0, length); + return slice; + } + #endregion + private Shape() { } @@ -107,7 +117,7 @@ namespace Tensorflow public long this[int n] { - get => dims[n]; + get => n < 0 ? dims[ndim + n] : dims[n]; set => dims[n] = value; } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 88bfb237..b1f7e41b 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -774,10 +774,10 @@ namespace Tensorflow int k = 0, int num_rows = -1, int num_cols = -1, - double padding_value = 0, + float padding_value = 0f, string align = "RIGHT_LEFT") => tf.Context.ExecuteOp("MatrixDiagV3", name, - new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value) + new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype)) .SetAttributes(new { align })); public static Tensor matrix_set_diag(Tensor input, @@ -900,7 +900,7 @@ namespace Tensorflow return gen_array_ops.gather_v2(@params, indices, axis, name: name); } - public static Tensor transpose(T1 a, Shape perm, string name = "transpose", bool conjugate = false) + public static Tensor transpose(T1 a, Axis perm, string name = "transpose", bool conjugate = false) { return tf_with(ops.name_scope(name, "transpose", new { a }), scope => { diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs index 33fbe953..6a0b869c 100644 --- a/src/TensorFlowNET.Core/Operations/linalg_ops.cs +++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs @@ -20,11 +20,12 @@ namespace Tensorflow var diag_size = Math.Min(num_rows, num_columns); if (batch_shape == null) batch_shape = new Shape(new int[0]); - var diag_shape = batch_shape.dims.concat(new long[] { diag_size }); + var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape"); + var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0); - long[] shape = null; + Tensor shape = null; if (!is_square) - shape = batch_shape.dims.concat(new long[] { num_rows, num_columns }); + shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0); var diag_ones = array_ops.ones(diag_shape, dtype: dtype); if (is_square) @@ -36,5 +37,81 @@ namespace Tensorflow } }); } + + public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixInverse", name, + new ExecuteOpArgs(input).SetAttributes(new + { + adjoint + })); + + public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs, + Tensor l2_regularizer = null, bool fast = true, string name = null) + { + return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); + } + + Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + Shape matrix_shape = matrix.shape[^2..]; + if (matrix_shape.IsFullyDefined) + { + if (matrix_shape[-2] >= matrix_shape[-1]) + return _overdetermined(matrix, rhs, l2_regularizer); + else + return _underdetermined(matrix, rhs, l2_regularizer); + } + + throw new NotImplementedException(""); + } + + Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true); + return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true)); + } + + Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) + { + var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false); + return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true); + } + + Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind) + { + var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind); + + if (l2_regularizer != null) + { + var matrix_shape = array_ops.shape(matrix); + var batch_shape = matrix_shape[":-2"]; + var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2]; + var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype); + var small_dim_static = matrix.shape[first_kind ? -1 : -2]; + identity.shape = matrix.shape[..^2].concat(new[] { small_dim_static, small_dim_static }); + gramian += l2_regularizer * identity; + } + + return cholesky(gramian); + } + + public Tensor cholesky(Tensor input, string name = null) + => tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input)); + + public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null) + => tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope => + { + var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true); + var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true); + return x; + }); + + public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null) + => tf.Context.ExecuteOp("MatrixTriangularSolve", name, + new ExecuteOpArgs(matrix, rhs).SetAttributes(new + { + lower, + adjoint + })); } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 84094a6f..4fb481da 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -791,6 +791,18 @@ namespace Tensorflow if (transpose_b && adjoint_b) throw new ValueError("Only one of transpose_b and adjoint_b can be True."); + if(adjoint_a) + { + a = conj(a); + transpose_a = true; + } + + if (adjoint_b) + { + b = conj(b); + transpose_b = true; + } + result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); }); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 3c185cb4..fca4169c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -103,6 +103,8 @@ namespace Tensorflow public bool IsCreatedInGraphMode => isCreatedInGraphMode; public bool IsSparseTensor => this is SparseTensor; + public Tensor TensorShape => tf.shape(this); + /// /// Returns the shape of a tensor. /// diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 5fa6bdd9..5f2d74bd 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -166,7 +166,7 @@ namespace Tensorflow RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), - Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), + Shape ts => constant_op.constant(ts.size == 0 ? new long[0] : ts.dims, dtype: dtype, name: name), string str => constant_op.constant(str, dtype: tf.@string, name: name), string[] str => constant_op.constant(str, dtype: tf.@string, name: name), IEnumerable objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs index 6594651e..64b2d940 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -4,7 +4,7 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI { [TestClass] - public class LinalgTest + public class LinalgTest : EagerModeTestBase { [TestMethod] public void EyeTest() @@ -17,5 +17,33 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreEqual(0.0f, (double)tensor[2, 1]); Assert.AreEqual(1.0f, (double)tensor[2, 2]); } + + /// + /// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up + /// + [TestMethod] + public void LSTSQ() + { + var A_over = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f} }); + var A_under = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } }); + var b_over = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1)); + var b_under = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1)); + var x_over = tf.linalg.lstsq(A_over, b_over); + + var x = tf.matmul(tf.linalg.inv(tf.matmul(A_over, A_over, transpose_a: true)), tf.matmul(A_over, b_over, transpose_a: true)); + Assert.AreEqual(x_over.shape, (2, 1)); + AssetSequenceEqual(x_over.ToArray(), x.ToArray()); + + var x_under = tf.linalg.lstsq(A_under, b_under); + var y = tf.matmul(A_under, tf.matmul(tf.linalg.inv(tf.matmul(A_under, A_under, transpose_b: true)), b_under), transpose_a: true); + + Assert.AreEqual(x_under.shape, (4, 1)); + AssetSequenceEqual(x_under.ToArray(), y.ToArray()); + + var x_over_reg = tf.linalg.lstsq(A_over, b_over, l2_regularizer: 2.0f); + var x_under_reg = tf.linalg.lstsq(A_under, b_under, l2_regularizer: 2.0f); + Assert.AreEqual(x_under_reg.shape, (4, 1)); + AssetSequenceEqual(x_under_reg.ToArray(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f }); + } } } diff --git a/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs new file mode 100644 index 00000000..81c5e2c3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs @@ -0,0 +1,23 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.NumPy; + +namespace TensorFlowNET.UnitTest.NumPy +{ + /// + /// https://numpy.org/doc/stable/reference/generated/numpy.prod.html + /// + [TestClass] + public class LinearAlgebraTest : EagerModeTestBase + { + [TestMethod] + public void lstsq() + { + + } + } +} diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 8ab25ee6..f0246337 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -62,6 +62,20 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); } + public void AssetSequenceEqual(float[] expected, float[] actual) + { + float eps = 1e-5f; + for (int i = 0; i < expected.Length; i++) + Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); + } + + public void AssetSequenceEqual(double[] expected, double[] actual) + { + double eps = 1e-5f; + for (int i = 0; i < expected.Length; i++) + Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); + } + public void assertEqual(object given, object expected) { /*if (given is NDArray && expected is NDArray)