Browse Source

add tf.linalg.eye.

tags/v0.20
Oceania2018 5 years ago
parent
commit
388b64510f
7 changed files with 205 additions and 32 deletions
  1. +24
    -1
      src/TensorFlowNET.Core/APIs/tf.linalg.cs
  2. +40
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +43
    -0
      src/TensorFlowNET.Core/Operations/linalg_ops.cs
  4. +40
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs
  5. +32
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs
  6. +2
    -31
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  7. +24
    -0
      test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs

+ 24
- 1
src/TensorFlowNET.Core/APIs/tf.linalg.cs View File

@@ -18,10 +18,33 @@ namespace Tensorflow
{
public partial class tensorflow
{
public LinalgApi linalg { get; } = new LinalgApi();

public class LinalgApi
{
linalg_ops ops = new linalg_ops();

public Tensor eye(int num_rows,
int num_columns = -1,
TensorShape batch_shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
string name = null)
=> ops.eye(num_rows, num_columns: num_columns, batch_shape: batch_shape, dtype: dtype, name: name);

public Tensor diag(Tensor diagonal, string name = null)
=> gen_array_ops.diag(diagonal, name: name);

public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);

public Tensor batch_matmul(Tensor x, Tensor y)
=> gen_math_ops.batch_mat_mul(x, y);
}

public Tensor diag(Tensor diagonal, string name = null)
=> gen_array_ops.diag(diagonal, name: name);

public Tensor matmul(Tensor a, Tensor b)
public Tensor matmul(Tensor a, Tensor b)
=> math_ops.matmul(a, b);

public Tensor batch_matmul(Tensor x, Tensor y)


+ 40
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -599,6 +599,46 @@ namespace Tensorflow
public static Tensor invert_permutation(Tensor x, string name = null)
=> gen_array_ops.invert_permutation(x, name: name);

public static Tensor matrix_diag(Tensor diagonal,
string name = "diag",
int k = 0,
int num_rows = -1,
int num_cols = -1,
float padding_value = 0,
string align = "RIGHT_LEFT")
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"MatrixDiagV3", name,
null,
diagonal, k, num_rows, num_cols, padding_value,
"align", align);
return results[0];
}

throw new NotImplementedException("");
}

public static Tensor matrix_set_diag(Tensor input,
Tensor diagonal,
string name = "set_diag",
int k = 0,
string align = "RIGHT_LEFT")
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"MatrixSetDiagV3", name,
null,
input, diagonal, k,
"align", align);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Computes the shape of a broadcast given symbolic shapes.
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of


+ 43
- 0
src/TensorFlowNET.Core/Operations/linalg_ops.cs View File

@@ -0,0 +1,43 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class linalg_ops
{
public Tensor eye(int num_rows,
int num_columns = -1,
TensorShape batch_shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
string name = null)
{
return tf_with(ops.name_scope(name, default_name: "eye", new { num_rows, num_columns, batch_shape }), scope =>
{
if (num_columns == -1)
num_columns = num_rows;

bool is_square = num_columns == num_rows;
var diag_size = Math.Min(num_rows, num_columns);
if (batch_shape == null)
batch_shape = new TensorShape(new int[0]);
var diag_shape = batch_shape.dims.concat(new[] { diag_size });

int[] shape = null;
if (!is_square)
shape = batch_shape.dims.concat(new[] { num_rows, num_columns });

var diag_ones = array_ops.ones(diag_shape, dtype: dtype);
if (is_square)
return array_ops.matrix_diag(diag_ones);
else
{
var zero_matrix = array_ops.zeros(shape, dtype: dtype);
return array_ops.matrix_set_diag(zero_matrix, diag_ones);
}
});
}
}
}

+ 40
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs View File

@@ -0,0 +1,40 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class TensorShape
{
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[])shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[])shape.dims.Clone());

public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);

public static explicit operator int(TensorShape shape) => shape.size;
public static implicit operator TensorShape(int dim) => new TensorShape(dim);

public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);

public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);

public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);

public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);

public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}

+ 32
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs View File

@@ -0,0 +1,32 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public partial class TensorShape
{
public override bool Equals(Object obj)
{
switch (obj)
{
case TensorShape shape1:
return Enumerable.SequenceEqual(shape1.dims, dims);
default:
return false;
}
}

/*public static bool operator ==(TensorShape shape1, TensorShape shape2)
{
return false;
}

public static bool operator !=(TensorShape shape1, TensorShape shape2)
{
return false;
}*/
}
}

+ 2
- 31
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -1,6 +1,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
@@ -12,7 +13,7 @@ namespace Tensorflow
/// Represents the shape of a `Tensor`.
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks>
public class TensorShape
public partial class TensorShape
{
private readonly Shape shape;

@@ -255,35 +256,5 @@ namespace Tensorflow
{
return shape.ToString();
}

public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);

public static explicit operator int(TensorShape shape) => shape.size;
public static implicit operator TensorShape(int dim) => new TensorShape(dim);

public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);

public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);

public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
}
}

+ 24
- 0
test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs View File

@@ -0,0 +1,24 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.UnitTest.TF_API
{
[TestClass]
public class LinalgTest
{
[TestMethod]
public void EyeTest()
{
var tensor = tf.linalg.eye(3);

Assert.AreEqual((3, 3), tensor.TensorShape);

Assert.AreEqual(0.0f, (float)tensor[2, 0]);
Assert.AreEqual(0.0f, (float)tensor[2, 1]);
Assert.AreEqual(1.0f, (float)tensor[2, 2]);
}
}
}

Loading…
Cancel
Save