Browse Source

tf.linalg.lstsq #823

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
1cf190fba5
15 changed files with 206 additions and 13 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +10
    -2
      src/TensorFlowNET.Core/APIs/tf.linalg.cs
  3. +3
    -0
      src/TensorFlowNET.Core/NumPy/Axis.cs
  4. +14
    -0
      src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs
  5. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Numpy/Numpy.cs
  7. +11
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  8. +3
    -3
      src/TensorFlowNET.Core/Operations/array_ops.cs
  9. +80
    -3
      src/TensorFlowNET.Core/Operations/linalg_ops.cs
  10. +12
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  11. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  12. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  13. +29
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs
  14. +23
    -0
      test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs
  15. +14
    -0
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -152,7 +152,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="conjugate"></param>
/// <returns></returns>
public Tensor transpose<T1>(T1 a, Shape perm = null, string name = "transpose", bool conjugate = false)
public Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);

/// <summary>


+ 10
- 2
src/TensorFlowNET.Core/APIs/tf.linalg.cs View File

@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -40,13 +41,20 @@ namespace Tensorflow

public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);

public Tensor inv(Tensor input, bool adjoint = false, string name = null)
=> ops.matrix_inverse(input, adjoint: adjoint, name: name);

public Tensor lstsq(Tensor matrix, Tensor rhs,
NDArray l2_regularizer = null, bool fast = true, string name = null)
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
}

public Tensor diag(Tensor diagonal, string name = null)
=> gen_array_ops.diag(diagonal, name: name);

public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);
public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
=> math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b);

/// <summary>
/// Multiply slices of the two matrices "x" and "y".


+ 3
- 0
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -50,6 +50,9 @@ namespace Tensorflow

public static implicit operator Tensor(Axis axis)
=> constant_op.constant(axis);

public override string ToString()
=> $"({string.Join(", ", axis)})";
}
}



+ 14
- 0
src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.NumPy
{
public class LinearAlgebraImpl
{
public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn")
{
return a;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.NumPy
=> new NDArray(value);

public static implicit operator Tensor(NDArray nd)
=> nd._tensor;
=> nd?._tensor;

public static implicit operator NDArray(Tensor tensor)
=> new NDArray(tensor);


+ 2
- 0
src/TensorFlowNET.Core/Numpy/Numpy.cs View File

@@ -105,5 +105,7 @@ namespace Tensorflow.NumPy
{
throw new NotImplementedException("");
}

public static LinearAlgebraImpl linalg = new LinearAlgebraImpl();
}
}

+ 11
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -38,6 +38,16 @@ namespace Tensorflow
}
}

#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
public int Length => ndim;
public long[] Slice(int start, int length)
{
var slice = new long[length];
Array.Copy(_dims, start, slice, 0, length);
return slice;
}
#endregion

private Shape()
{
}
@@ -107,7 +117,7 @@ namespace Tensorflow

public long this[int n]
{
get => dims[n];
get => n < 0 ? dims[ndim + n] : dims[n];
set => dims[n] = value;
}



+ 3
- 3
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -774,10 +774,10 @@ namespace Tensorflow
int k = 0,
int num_rows = -1,
int num_cols = -1,
double padding_value = 0,
float padding_value = 0f,
string align = "RIGHT_LEFT")
=> tf.Context.ExecuteOp("MatrixDiagV3", name,
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value)
new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype))
.SetAttributes(new { align }));

public static Tensor matrix_set_diag(Tensor input,
@@ -900,7 +900,7 @@ namespace Tensorflow
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}

public static Tensor transpose<T1>(T1 a, Shape perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{


+ 80
- 3
src/TensorFlowNET.Core/Operations/linalg_ops.cs View File

@@ -20,11 +20,12 @@ namespace Tensorflow
var diag_size = Math.Min(num_rows, num_columns);
if (batch_shape == null)
batch_shape = new Shape(new int[0]);
var diag_shape = batch_shape.dims.concat(new long[] { diag_size });
var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape");
var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0);

long[] shape = null;
Tensor shape = null;
if (!is_square)
shape = batch_shape.dims.concat(new long[] { num_rows, num_columns });
shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0);

var diag_ones = array_ops.ones(diag_shape, dtype: dtype);
if (is_square)
@@ -36,5 +37,81 @@ namespace Tensorflow
}
});
}

public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
=> tf.Context.ExecuteOp("MatrixInverse", name,
new ExecuteOpArgs(input).SetAttributes(new
{
adjoint
}));

public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs,
Tensor l2_regularizer = null, bool fast = true, string name = null)
{
return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer);
}

Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
{
Shape matrix_shape = matrix.shape[^2..];
if (matrix_shape.IsFullyDefined)
{
if (matrix_shape[-2] >= matrix_shape[-1])
return _overdetermined(matrix, rhs, l2_regularizer);
else
return _underdetermined(matrix, rhs, l2_regularizer);
}

throw new NotImplementedException("");
}

Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
{
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true);
return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true));
}

Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
{
var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false);
return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true);
}

Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind)
{
var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind);

if (l2_regularizer != null)
{
var matrix_shape = array_ops.shape(matrix);
var batch_shape = matrix_shape[":-2"];
var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2];
var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype);
var small_dim_static = matrix.shape[first_kind ? -1 : -2];
identity.shape = matrix.shape[..^2].concat(new[] { small_dim_static, small_dim_static });
gramian += l2_regularizer * identity;
}

return cholesky(gramian);
}

public Tensor cholesky(Tensor input, string name = null)
=> tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input));

public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null)
=> tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope =>
{
var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true);
var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true);
return x;
});

public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null)
=> tf.Context.ExecuteOp("MatrixTriangularSolve", name,
new ExecuteOpArgs(matrix, rhs).SetAttributes(new
{
lower,
adjoint
}));
}
}

+ 12
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -791,6 +791,18 @@ namespace Tensorflow
if (transpose_b && adjoint_b)
throw new ValueError("Only one of transpose_b and adjoint_b can be True.");

if(adjoint_a)
{
a = conj(a);
transpose_a = true;
}

if (adjoint_b)
{
b = conj(b);
transpose_b = true;
}

result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name);
});



+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -103,6 +103,8 @@ namespace Tensorflow
public bool IsCreatedInGraphMode => isCreatedInGraphMode;
public bool IsSparseTensor => this is SparseTensor;

public Tensor TensorShape => tf.shape(this);

/// <summary>
/// Returns the shape of a tensor.
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -166,7 +166,7 @@ namespace Tensorflow
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
Shape ts => constant_op.constant(ts.size == 0 ? new long[0] : ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
IEnumerable<object> objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name),


+ 29
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs View File

@@ -4,7 +4,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
public class LinalgTest
public class LinalgTest : EagerModeTestBase
{
[TestMethod]
public void EyeTest()
@@ -17,5 +17,33 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(0.0f, (double)tensor[2, 1]);
Assert.AreEqual(1.0f, (double)tensor[2, 2]);
}

/// <summary>
/// https://colab.research.google.com/github/biswajitsahoo1111/blog_notebooks/blob/master/Doing_Linear_Algebra_using_Tensorflow_2.ipynb#scrollTo=6xfOcTFBL3Up
/// </summary>
[TestMethod]
public void LSTSQ()
{
var A_over = tf.constant(new float[,] { { 1, 2 }, { 2, 0.5f }, { 3, 1 }, { 4, 5.0f} });
var A_under = tf.constant(new float[,] { { 3, 1, 2, 5 }, { 7, 9, 1, 4.0f } });
var b_over = tf.constant(new float[] { 3, 4, 5, 6.0f }, shape: (4, 1));
var b_under = tf.constant(new float[] { 7.2f, -5.8f }, shape: (2, 1));
var x_over = tf.linalg.lstsq(A_over, b_over);

var x = tf.matmul(tf.linalg.inv(tf.matmul(A_over, A_over, transpose_a: true)), tf.matmul(A_over, b_over, transpose_a: true));
Assert.AreEqual(x_over.shape, (2, 1));
AssetSequenceEqual(x_over.ToArray<float>(), x.ToArray<float>());

var x_under = tf.linalg.lstsq(A_under, b_under);
var y = tf.matmul(A_under, tf.matmul(tf.linalg.inv(tf.matmul(A_under, A_under, transpose_b: true)), b_under), transpose_a: true);

Assert.AreEqual(x_under.shape, (4, 1));
AssetSequenceEqual(x_under.ToArray<float>(), y.ToArray<float>());

var x_over_reg = tf.linalg.lstsq(A_over, b_over, l2_regularizer: 2.0f);
var x_under_reg = tf.linalg.lstsq(A_under, b_under, l2_regularizer: 2.0f);
Assert.AreEqual(x_under_reg.shape, (4, 1));
AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });
}
}
}

+ 23
- 0
test/TensorFlowNET.UnitTest/NumPy/LinearAlgebra.Test.cs View File

@@ -0,0 +1,23 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;

namespace TensorFlowNET.UnitTest.NumPy
{
/// <summary>
/// https://numpy.org/doc/stable/reference/generated/numpy.prod.html
/// </summary>
[TestClass]
public class LinearAlgebraTest : EagerModeTestBase
{
[TestMethod]
public void lstsq()
{

}
}
}

+ 14
- 0
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -62,6 +62,20 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}");
}

public void AssetSequenceEqual(float[] expected, float[] actual)
{
float eps = 1e-5f;
for (int i = 0; i < expected.Length; i++)
Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}");
}

public void AssetSequenceEqual(double[] expected, double[] actual)
{
double eps = 1e-5f;
for (int i = 0; i < expected.Length; i++)
Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}");
}

public void assertEqual(object given, object expected)
{
/*if (given is NDArray && expected is NDArray)


Loading…
Cancel
Save