Browse Source

tf.linalg.einsum #885

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
3072de5b4d
4 changed files with 38 additions and 1 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.linalg.cs
  2. +5
    -1
      src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs
  3. +21
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  4. +9
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.linalg.cs View File

@@ -26,6 +26,9 @@ namespace Tensorflow
{
linalg_ops ops = new linalg_ops();

public Tensor einsum(string equation, Tensors inputs, string name = null)
=> math_ops.einsum(equation, inputs, name: name);

public Tensor eye(int num_rows,
int num_columns = -1,
Shape batch_shape = null,


+ 5
- 1
src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs View File

@@ -11,7 +11,11 @@ namespace Tensorflow
public Func<Operation, object> GetGradientAttrs { get; set; }
public object[] OpInputArgs { get; set; }
public Dictionary<string, object> OpAttrs { get; set; }

/// <summary>
///
/// </summary>
/// <param name="inputArgs">For array: OpInputArgs = new object[]{ }</param>
[DebuggerStepThrough]
public ExecuteOpArgs(params object[] inputArgs)
{


+ 21
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -230,6 +230,27 @@ namespace Tensorflow
});
}

public static Tensor einsum(string equation, Tensors inputs, string name = null)
{
return tf_with(ops.name_scope(name, "einsum", inputs), scope =>
{
name = scope;
return tf.Context.ExecuteOp("Einsum", name, new ExecuteOpArgs
{
OpInputArgs = new object[] { inputs.ToArray() },
GetGradientAttrs = (op) => new
{
equation = op.get_attr<string>("equation"),
N = op.get_attr<int>("N"),
T = op.get_attr<TF_DataType>("T")
}
}.SetAttributes(new
{
equation = equation
}));
});
}

public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name);
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)


+ 9
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs View File

@@ -45,5 +45,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(x_under_reg.shape, (4, 1));
AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/
}

[TestMethod]
public void Einsum()
{
var m0 = tf.random.normal((2, 3));
var m1 = tf.random.normal((3, 5));
var e = tf.linalg.einsum("ij,jk->ik", (m0, m1));
Assert.AreEqual(e.shape, (2, 5));
}
}
}

Loading…
Cancel
Save