diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 8574b838..be614294 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -152,7 +152,7 @@ namespace Tensorflow
///
///
///
- public Tensor transpose(T1 a, Shape perm = null, string name = "transpose", bool conjugate = false)
+ public Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index 7d4e418a..2b6051e0 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
+using Tensorflow.NumPy;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -40,13 +41,20 @@ namespace Tensorflow
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name);
+
+ public Tensor inv(Tensor input, bool adjoint = false, string name = null)
+ => ops.matrix_inverse(input, adjoint: adjoint, name: name);
+
+ public Tensor lstsq(Tensor matrix, Tensor rhs,
+ NDArray l2_regularizer = null, bool fast = true, string name = null)
+ => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
}
public Tensor diag(Tensor diagonal, string name = null)
=> gen_array_ops.diag(diagonal, name: name);
- public Tensor matmul(Tensor a, Tensor b)
- => math_ops.matmul(a, b);
+ public Tensor matmul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
+ => math_ops.matmul(a, b, transpose_a: transpose_a, transpose_b: transpose_b);
///
/// Multiply slices of the two matrices "x" and "y".
diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs
index 45f05ed7..4c7b6488 100644
--- a/src/TensorFlowNET.Core/NumPy/Axis.cs
+++ b/src/TensorFlowNET.Core/NumPy/Axis.cs
@@ -50,6 +50,9 @@ namespace Tensorflow
public static implicit operator Tensor(Axis axis)
=> constant_op.constant(axis);
+
+ public override string ToString()
+ => $"({string.Join(", ", axis)})";
}
}
diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs
new file mode 100644
index 00000000..92ef6b69
--- /dev/null
+++ b/src/TensorFlowNET.Core/NumPy/Implementation/LinearAlgebraImpl.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.NumPy
+{
+ public class LinearAlgebraImpl
+ {
+ public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn")
+ {
+ return a;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
index 515c3dcb..3b5e028a 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
@@ -48,7 +48,7 @@ namespace Tensorflow.NumPy
=> new NDArray(value);
public static implicit operator Tensor(NDArray nd)
- => nd._tensor;
+ => nd?._tensor;
public static implicit operator NDArray(Tensor tensor)
=> new NDArray(tensor);
diff --git a/src/TensorFlowNET.Core/Numpy/Numpy.cs b/src/TensorFlowNET.Core/Numpy/Numpy.cs
index 85cbeb71..7131b425 100644
--- a/src/TensorFlowNET.Core/Numpy/Numpy.cs
+++ b/src/TensorFlowNET.Core/Numpy/Numpy.cs
@@ -105,5 +105,7 @@ namespace Tensorflow.NumPy
{
throw new NotImplementedException("");
}
+
+ public static LinearAlgebraImpl linalg = new LinearAlgebraImpl();
}
}
diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs
index a1068215..263550e3 100644
--- a/src/TensorFlowNET.Core/Numpy/Shape.cs
+++ b/src/TensorFlowNET.Core/Numpy/Shape.cs
@@ -38,6 +38,16 @@ namespace Tensorflow
}
}
+ #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
+ public int Length => ndim;
+ public long[] Slice(int start, int length)
+ {
+ var slice = new long[length];
+ Array.Copy(_dims, start, slice, 0, length);
+ return slice;
+ }
+ #endregion
+
private Shape()
{
}
@@ -107,7 +117,7 @@ namespace Tensorflow
public long this[int n]
{
- get => dims[n];
+ get => n < 0 ? dims[ndim + n] : dims[n];
set => dims[n] = value;
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 88bfb237..b1f7e41b 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -774,10 +774,10 @@ namespace Tensorflow
int k = 0,
int num_rows = -1,
int num_cols = -1,
- double padding_value = 0,
+ float padding_value = 0f,
string align = "RIGHT_LEFT")
=> tf.Context.ExecuteOp("MatrixDiagV3", name,
- new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value)
+ new ExecuteOpArgs(diagonal, k, num_rows, num_cols, ops.convert_to_tensor(padding_value, dtype: diagonal.dtype))
.SetAttributes(new { align }));
public static Tensor matrix_set_diag(Tensor input,
@@ -900,7 +900,7 @@ namespace Tensorflow
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}
- public static Tensor transpose(T1 a, Shape perm, string name = "transpose", bool conjugate = false)
+ public static Tensor transpose(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
diff --git a/src/TensorFlowNET.Core/Operations/linalg_ops.cs b/src/TensorFlowNET.Core/Operations/linalg_ops.cs
index 33fbe953..6a0b869c 100644
--- a/src/TensorFlowNET.Core/Operations/linalg_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/linalg_ops.cs
@@ -20,11 +20,12 @@ namespace Tensorflow
var diag_size = Math.Min(num_rows, num_columns);
if (batch_shape == null)
batch_shape = new Shape(new int[0]);
- var diag_shape = batch_shape.dims.concat(new long[] { diag_size });
+ var batch_shape_tensor = ops.convert_to_tensor(batch_shape, dtype: tf.int32, name: "shape");
+ var diag_shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { diag_size }) }, axis: 0);
- long[] shape = null;
+ Tensor shape = null;
if (!is_square)
- shape = batch_shape.dims.concat(new long[] { num_rows, num_columns });
+ shape = array_ops.concat(new[] { batch_shape_tensor, tf.constant(new int[] { num_rows, num_columns }) }, axis: 0);
var diag_ones = array_ops.ones(diag_shape, dtype: dtype);
if (is_square)
@@ -36,5 +37,81 @@ namespace Tensorflow
}
});
}
+
+ public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
+ => tf.Context.ExecuteOp("MatrixInverse", name,
+ new ExecuteOpArgs(input).SetAttributes(new
+ {
+ adjoint
+ }));
+
+ public Tensor matrix_solve_ls(Tensor matrix, Tensor rhs,
+ Tensor l2_regularizer = null, bool fast = true, string name = null)
+ {
+ return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer);
+ }
+
+ Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
+ {
+ Shape matrix_shape = matrix.shape[^2..];
+ if (matrix_shape.IsFullyDefined)
+ {
+ if (matrix_shape[-2] >= matrix_shape[-1])
+ return _overdetermined(matrix, rhs, l2_regularizer);
+ else
+ return _underdetermined(matrix, rhs, l2_regularizer);
+ }
+
+ throw new NotImplementedException("");
+ }
+
+ Tensor _overdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
+ {
+ var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: true);
+ return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a: true));
+ }
+
+ Tensor _underdetermined(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null)
+ {
+ var chol = _RegularizedGramianCholesky(matrix, l2_regularizer: l2_regularizer, first_kind: false);
+ return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a: true);
+ }
+
+ Tensor _RegularizedGramianCholesky(Tensor matrix, Tensor l2_regularizer, bool first_kind)
+ {
+ var gramian = math_ops.matmul(matrix, matrix, adjoint_a: first_kind, adjoint_b: !first_kind);
+
+ if (l2_regularizer != null)
+ {
+ var matrix_shape = array_ops.shape(matrix);
+ var batch_shape = matrix_shape[":-2"];
+ var small_dim = first_kind ? matrix_shape[-1] : matrix_shape[-2];
+ var identity = eye(small_dim.numpy(), batch_shape: batch_shape.shape, dtype: matrix.dtype);
+ var small_dim_static = matrix.shape[first_kind ? -1 : -2];
+ identity.shape = matrix.shape[..^2].concat(new[] { small_dim_static, small_dim_static });
+ gramian += l2_regularizer * identity;
+ }
+
+ return cholesky(gramian);
+ }
+
+ public Tensor cholesky(Tensor input, string name = null)
+ => tf.Context.ExecuteOp("Cholesky", name, new ExecuteOpArgs(input));
+
+ public Tensor cholesky_solve(Tensor chol, Tensor rhs, string name = null)
+ => tf_with(ops.name_scope(name, default_name: "eye", new { chol, rhs }), scope =>
+ {
+ var y = matrix_triangular_solve(chol, rhs, adjoint: false, lower: true);
+ var x = matrix_triangular_solve(chol, y, adjoint: true, lower: true);
+ return x;
+ });
+
+ public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = true, bool adjoint = false, string name = null)
+ => tf.Context.ExecuteOp("MatrixTriangularSolve", name,
+ new ExecuteOpArgs(matrix, rhs).SetAttributes(new
+ {
+ lower,
+ adjoint
+ }));
}
}
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 84094a6f..4fb481da 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -791,6 +791,18 @@ namespace Tensorflow
if (transpose_b && adjoint_b)
throw new ValueError("Only one of transpose_b and adjoint_b can be True.");
+ if(adjoint_a)
+ {
+ a = conj(a);
+ transpose_a = true;
+ }
+
+ if (adjoint_b)
+ {
+ b = conj(b);
+ transpose_b = true;
+ }
+
result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name);
});
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 3c185cb4..fca4169c 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -103,6 +103,8 @@ namespace Tensorflow
public bool IsCreatedInGraphMode => isCreatedInGraphMode;
public bool IsSparseTensor => this is SparseTensor;
+ public Tensor TensorShape => tf.shape(this);
+
///
/// Returns the shape of a tensor.
///
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 5fa6bdd9..5f2d74bd 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -166,7 +166,7 @@ namespace Tensorflow
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
- Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
+ Shape ts => constant_op.constant(ts.size == 0 ? new long[0] : ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
IEnumerable