@@ -152,7 +152,7 @@ namespace Tensorflow | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <param name="conjugate"></param> | /// <param name="conjugate"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor transpose<T1>(T1 a, Shape perm = null, string name = "transpose", bool conjugate = false) | |||||
public Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false) | |||||
=> array_ops.transpose(a, perm, name, conjugate); | => array_ops.transpose(a, perm, name, conjugate); | ||||
/// <summary> | /// <summary> | ||||
@@ -13,6 +13,7 @@ | |||||
See the License for the specific language governing permissions and | See the License for the specific language governing permissions and | ||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -40,13 +41,20 @@ namespace Tensorflow | |||||
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | ||||
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); | => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); | ||||
public Tensor inv(Tensor input, bool adjoint = false, string name = null) | |||||
=> ops.matrix_inverse(input, adjoint: adjoint, name: name); | |||||
public Tensor lstsq(Tensor matrix, Tensor rhs, | |||||
NDArray l2_regularizer = null, bool fast = true, string name = null) | |||||
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name); | |||||
} | } | ||||
public Tensor diag(Tensor diagonal, string name = null) | public Tensor diag(Tensor diagonal, string name = null) | ||||
=> gen_array_ops.diag(diagonal, name: name); | => gen_array_ops.diag(diagonal, name: name); | ||||
public Tensor matmul(Tensor a, Tensor b) | |||||
=> math_ops.matmul(a, b); | |||||
public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) | |||||
=> math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b); | |||||
/// <summary> | /// <summary> | ||||
/// Multiply slices of the two matrices "x" and "y". | /// Multiply slices of the two matrices "x" and "y". | ||||
@@ -50,6 +50,9 @@ namespace Tensorflow | |||||
public static implicit operator Tensor(Axis axis) | public static implicit operator Tensor(Axis axis) | ||||
=> constant_op.constant(axis); | => constant_op.constant(axis); | ||||
public override string ToString() | |||||
=> $"({string.Join(", ", axis)})"; | |||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.NumPy | |||||
{ | |||||
public class LinearAlgebraImpl | |||||
{ | |||||
public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn") | |||||
{ | |||||
return a; | |||||
} | |||||
} | |||||
} |
@@ -48,7 +48,7 @@ namespace Tensorflow.NumPy | |||||
=> new NDArray(value); | => new NDArray(value); | ||||
public static implicit operator Tensor(NDArray nd) | public static implicit operator Tensor(NDArray nd) | ||||
=> nd._tensor; | |||||
=> nd?._tensor; | |||||
public static implicit operator NDArray(Tensor tensor) | public static implicit operator NDArray(Tensor tensor) | ||||
=> new NDArray(tensor); | => new NDArray(tensor); | ||||
@@ -105,5 +105,7 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
public static LinearAlgebraImpl linalg = new LinearAlgebraImpl(); | |||||
} | } | ||||
} | } |
@@ -38,6 +38,16 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | |||||
public int Length => ndim; | |||||
public long[] Slice(int start, int length) | |||||
{ | |||||
var slice = new long[length]; | |||||
Array.Copy(_dims, start, slice, 0, length); | |||||
return slice; | |||||
} | |||||
#endregion | |||||
private Shape() | private Shape() | ||||
{ | { | ||||
} | } | ||||
@@ -107,7 +117,7 @@ namespace Tensorflow | |||||
public long this[int n] | public long this[int n] | ||||
{ | { | ||||
get => dims[n]; | |||||
get => n < 0 ? dims[ndim + n] : dims[n]; | |||||
set => dims[n] = value; | set => dims[n] = value; | ||||
} | } | ||||
@@ -774,10 +774,10 @@ namespace Tensorflow | |||||
int k = 0, | int k = 0, | ||||
int num_rows = -1, | int num_rows = -1, | ||||
int num_cols = -1, | int num_cols = -1, | ||||
double padding_value = 0, | |||||
float padding_value = 0f, | |||||
string align = "RIGHT_LEFT") | string align = "RIGHT_LEFT") | ||||
=> tf.Context.ExecuteOp("MatrixDiagV3", name, | => tf.Context.ExecuteOp("MatrixDiagV3", name, | ||||
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value) | |||||
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype)) | |||||
.SetAttributes(new { align })); | .SetAttributes(new { align })); | ||||
public static Tensor matrix_set_diag(Tensor input, | public static Tensor matrix_set_diag(Tensor input, | ||||
@@ -900,7 +900,7 @@ namespace Tensorflow | |||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
} | } | ||||
public static Tensor transpose<T1>(T1 a, Shape perm, string name = "transpose", bool conjugate = false) | |||||
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false) | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | ||||
{ | { | ||||
@@ -20,11 +20,12 @@ namespace Tensorflow | |||||
var diag_size = Math.Min(num_rows, num_columns); | var diag_size = Math.Min(num_rows, num_columns); | ||||
if (batch_shape == null) | if (batch_shape == null) | ||||
batch_shape = new Shape(new int[0]); | batch_shape = new Shape(new int[0]); | ||||
var diag_shape = batch_shape.dims.concat(new long[] { diag_size }); | |||||
var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape"); | |||||
var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0); | |||||
long[] shape = null; | |||||
Tensor shape = null; | |||||
if (!is_square) | if (!is_square) | ||||
shape = batch_shape.dims.concat(new long[] { num_rows, num_columns }); | |||||
shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0); | |||||
var diag_ones = array_ops.ones(diag_shape, dtype: dtype); | var diag_ones = array_ops.ones(diag_shape, dtype: dtype); | ||||
if (is_square) | if (is_square) | ||||
@@ -36,5 +37,81 @@ namespace Tensorflow | |||||
} | } | ||||
}); | }); | ||||
} | } | ||||
public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null) | |||||
=> tf.Context.ExecuteOp("MatrixInverse", name, | |||||
new ExecuteOpArgs(input).SetAttributes(new | |||||
{ | |||||
adjoint | |||||
})); | |||||
public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs, | |||||
Tensor l2_regularizer = null, bool fast = true, string name = null) | |||||
{ | |||||
return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); | |||||
} | |||||
Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) | |||||
{ | |||||
Shape matrix_shape = matrix.shape[^2..]; | |||||
if (matrix_shape.IsFullyDefined) | |||||
{ | |||||
if (matrix_shape[-2] >= matrix_shape[-1]) | |||||
return _overdetermined(matrix, rhs, l2_regularizer); | |||||
else | |||||
return _underdetermined(matrix, rhs, l2_regularizer); | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) | |||||
{ | |||||
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true); | |||||
return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true)); | |||||
} | |||||
Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) | |||||
{ | |||||
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false); | |||||
return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true); | |||||
} | |||||
Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind) | |||||
{ | |||||
var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind); | |||||
if (l2_regularizer != null) | |||||
{ | |||||
var matrix_shape = array_ops.shape(matrix); | |||||
var batch_shape = matrix_shape[":-2"]; | |||||
var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2]; | |||||
var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype); | |||||
var small_dim_static = matrix.shape[first_kind ? -1 : -2]; | |||||
identity.shape = matrix.shape[..^2].concat(new[] { small_dim_static, small_dim_static }); | |||||
gramian += l2_regularizer * identity; | |||||
} | |||||
return cholesky(gramian); | |||||
} | |||||
public Tensor cholesky(Tensor input, string name = null) | |||||
=> tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input)); | |||||
public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null) | |||||
=> tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope => | |||||
{ | |||||
var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true); | |||||
var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true); | |||||
return x; | |||||
}); | |||||
public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null) | |||||
=> tf.Context.ExecuteOp("MatrixTriangularSolve", name, | |||||
new ExecuteOpArgs(matrix, rhs).SetAttributes(new | |||||
{ | |||||
lower, | |||||
adjoint | |||||
})); | |||||
} | } | ||||
} | } |
@@ -791,6 +791,18 @@ namespace Tensorflow | |||||
if (transpose_b && adjoint_b) | if (transpose_b && adjoint_b) | ||||
throw new ValueError("Only one of transpose_b and adjoint_b can be True."); | throw new ValueError("Only one of transpose_b and adjoint_b can be True."); | ||||
if(adjoint_a) | |||||
{ | |||||
a = conj(a); | |||||
transpose_a = true; | |||||
} | |||||
if (adjoint_b) | |||||
{ | |||||
b = conj(b); | |||||
transpose_b = true; | |||||
} | |||||
result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); | result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name); | ||||
}); | }); | ||||
@@ -103,6 +103,8 @@ namespace Tensorflow | |||||
public bool IsCreatedInGraphMode => isCreatedInGraphMode; | public bool IsCreatedInGraphMode => isCreatedInGraphMode; | ||||
public bool IsSparseTensor => this is SparseTensor; | public bool IsSparseTensor => this is SparseTensor; | ||||
public Tensor TensorShape => tf.shape(this); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -166,7 +166,7 @@ namespace Tensorflow | |||||
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ||||
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ||||
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), | Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), | ||||
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | |||||
Shape ts => constant_op.constant(ts.size == 0 ? new long[0] : ts.dims, dtype: dtype, name: name), | |||||
string str => constant_op.constant(str, dtype: tf.@string, name: name), | string str => constant_op.constant(str, dtype: tf.@string, name: name), | ||||
string[] str => constant_op.constant(str, dtype: tf.@string, name: name), | string[] str => constant_op.constant(str, dtype: tf.@string, name: name), | ||||
IEnumerable<object> objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), | IEnumerable<object> objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), | ||||
@@ -4,7 +4,7 @@ using static Tensorflow.Binding; | |||||
namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class LinalgTest | |||||
public class LinalgTest : EagerModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void EyeTest() | public void EyeTest() | ||||
@@ -17,5 +17,33 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
Assert.AreEqual(0.0f, (double)tensor[2, 1]); | Assert.AreEqual(0.0f, (double)tensor[2, 1]); | ||||
Assert.AreEqual(1.0f, (double)tensor[2, 2]); | Assert.AreEqual(1.0f, (double)tensor[2, 2]); | ||||
} | } | ||||
/// <summary> | |||||
/// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void LSTSQ() | |||||
{ | |||||
var A_over = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f} }); | |||||
var A_under = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } }); | |||||
var b_over = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1)); | |||||
var b_under = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1)); | |||||
var x_over = tf.linalg.lstsq(A_over, b_over); | |||||
var x = tf.matmul(tf.linalg.inv(tf.matmul(A_over, A_over, transpose_a: true)), tf.matmul(A_over, b_over, transpose_a: true)); | |||||
Assert.AreEqual(x_over.shape, (2, 1)); | |||||
AssetSequenceEqual(x_over.ToArray<float>(), x.ToArray<float>()); | |||||
var x_under = tf.linalg.lstsq(A_under, b_under); | |||||
var y = tf.matmul(A_under, tf.matmul(tf.linalg.inv(tf.matmul(A_under, A_under, transpose_b: true)), b_under), transpose_a: true); | |||||
Assert.AreEqual(x_under.shape, (4, 1)); | |||||
AssetSequenceEqual(x_under.ToArray<float>(), y.ToArray<float>()); | |||||
var x_over_reg = tf.linalg.lstsq(A_over, b_over, l2_regularizer: 2.0f); | |||||
var x_under_reg = tf.linalg.lstsq(A_under, b_under, l2_regularizer: 2.0f); | |||||
Assert.AreEqual(x_under_reg.shape, (4, 1)); | |||||
AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f }); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,23 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | |||||
namespace TensorFlowNET.UnitTest.NumPy | |||||
{ | |||||
/// <summary> | |||||
/// https://numpy.org/doc/stable/reference/generated/numpy.prod.html | |||||
/// </summary> | |||||
[TestClass] | |||||
public class LinearAlgebraTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void lstsq() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -62,6 +62,20 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); | Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); | ||||
} | } | ||||
public void AssetSequenceEqual(float[] expected, float[] actual) | |||||
{ | |||||
float eps = 1e-5f; | |||||
for (int i = 0; i < expected.Length; i++) | |||||
Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); | |||||
} | |||||
public void AssetSequenceEqual(double[] expected, double[] actual) | |||||
{ | |||||
double eps = 1e-5f; | |||||
for (int i = 0; i < expected.Length; i++) | |||||
Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}"); | |||||
} | |||||
public void assertEqual(object given, object expected) | public void assertEqual(object given, object expected) | ||||
{ | { | ||||
/*if (given is NDArray && expected is NDArray) | /*if (given is NDArray && expected is NDArray) | ||||