@@ -18,10 +18,33 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public LinalgApi linalg { get; } = new LinalgApi(); | |||
public class LinalgApi | |||
{ | |||
linalg_ops ops = new linalg_ops(); | |||
public Tensor eye(int num_rows, | |||
int num_columns = -1, | |||
TensorShape batch_shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
string name = null) | |||
=> ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name); | |||
public Tensor diag(Tensor diagonal, string name = null) | |||
=> gen_array_ops.diag(diagonal, name: name); | |||
public Tensor matmul(Tensor a, Tensor b) | |||
=> math_ops.matmul(a, b); | |||
public Tensor batch_matmul(Tensor x, Tensor y) | |||
=> gen_math_ops.batch_mat_mul(x, y); | |||
} | |||
public Tensor diag(Tensor diagonal, string name = null) | |||
=> gen_array_ops.diag(diagonal, name: name); | |||
public Tensor matmul(Tensor a, Tensor b) | |||
public Tensor matmul(Tensor a, Tensor b) | |||
=> math_ops.matmul(a, b); | |||
public Tensor batch_matmul(Tensor x, Tensor y) | |||
@@ -599,6 +599,46 @@ namespace Tensorflow | |||
public static Tensor invert_permutation(Tensor x, string name = null) | |||
=> gen_array_ops.invert_permutation(x, name: name); | |||
public static Tensor matrix_diag(Tensor diagonal, | |||
string name = "diag", | |||
int k = 0, | |||
int num_rows = -1, | |||
int num_cols = -1, | |||
float padding_value = 0, | |||
string align = "RIGHT_LEFT") | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
"MatrixDiagV3", name, | |||
null, | |||
diagonal, k, num_rows, num_cols, padding_value, | |||
"align", align); | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
} | |||
public static Tensor matrix_set_diag(Tensor input, | |||
Tensor diagonal, | |||
string name = "set_diag", | |||
int k = 0, | |||
string align = "RIGHT_LEFT") | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
"MatrixSetDiagV3", name, | |||
null, | |||
input, diagonal, k, | |||
"align", align); | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
} | |||
/// <summary> | |||
/// Computes the shape of a broadcast given symbolic shapes. | |||
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of | |||
@@ -0,0 +1,43 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class linalg_ops | |||
{ | |||
public Tensor eye(int num_rows, | |||
int num_columns = -1, | |||
TensorShape batch_shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope => | |||
{ | |||
if (num_columns == -1) | |||
num_columns = num_rows; | |||
bool is_square = num_columns == num_rows; | |||
var diag_size = Math.Min(num_rows, num_columns); | |||
if (batch_shape == null) | |||
batch_shape = new TensorShape(new int[0]); | |||
var diag_shape = batch_shape.dims.concat(new[] { diag_size }); | |||
int[] shape = null; | |||
if (!is_square) | |||
shape = batch_shape.dims.concat(new[] { num_rows, num_columns }); | |||
var diag_ones = array_ops.ones(diag_shape, dtype: dtype); | |||
if (is_square) | |||
return array_ops.matrix_diag(diag_ones); | |||
else | |||
{ | |||
var zero_matrix = array_ops.zeros(shape, dtype: dtype); | |||
return array_ops.matrix_set_diag(zero_matrix, diag_ones); | |||
} | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,40 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class TensorShape | |||
{ | |||
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[])shape.Dimensions.Clone()); | |||
public static implicit operator Shape(TensorShape shape) => new Shape((int[])shape.dims.Clone()); | |||
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | |||
public static explicit operator int(TensorShape shape) => shape.size; | |||
public static implicit operator TensorShape(int dim) => new TensorShape(dim); | |||
public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | |||
public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | |||
public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||
public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||
public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); | |||
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); | |||
} | |||
} |
@@ -0,0 +1,32 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class TensorShape | |||
{ | |||
public override bool Equals(Object obj) | |||
{ | |||
switch (obj) | |||
{ | |||
case TensorShape shape1: | |||
return Enumerable.SequenceEqual(shape1.dims, dims); | |||
default: | |||
return false; | |||
} | |||
} | |||
/*public static bool operator ==(TensorShape shape1, TensorShape shape2) | |||
{ | |||
return false; | |||
} | |||
public static bool operator !=(TensorShape shape1, TensorShape shape2) | |||
{ | |||
return false; | |||
}*/ | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Linq; | |||
using System.Runtime.CompilerServices; | |||
@@ -12,7 +13,7 @@ namespace Tensorflow | |||
/// Represents the shape of a `Tensor`. | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||
public class TensorShape | |||
public partial class TensorShape | |||
{ | |||
private readonly Shape shape; | |||
@@ -255,35 +256,5 @@ namespace Tensorflow | |||
{ | |||
return shape.ToString(); | |||
} | |||
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | |||
public static explicit operator int(TensorShape shape) => shape.size; | |||
public static implicit operator TensorShape(int dim) => new TensorShape(dim); | |||
public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | |||
public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | |||
public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||
public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||
public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); | |||
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); | |||
} | |||
} |
@@ -0,0 +1,24 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.UnitTest.TF_API | |||
{ | |||
[TestClass] | |||
public class LinalgTest | |||
{ | |||
[TestMethod] | |||
public void EyeTest() | |||
{ | |||
var tensor = tf.linalg.eye(3); | |||
Assert.AreEqual((3, 3), tensor.TensorShape); | |||
Assert.AreEqual(0.0f, (float)tensor[2, 0]); | |||
Assert.AreEqual(0.0f, (float)tensor[2, 1]); | |||
Assert.AreEqual(1.0f, (float)tensor[2, 2]); | |||
} | |||
} | |||
} |