From 3072de5b4dd1ebcfbc935cc5bf973993876b36a6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 1 Dec 2021 20:28:54 -0600 Subject: [PATCH] tf.linalg.einsum #885 --- src/TensorFlowNET.Core/APIs/tf.linalg.cs | 3 +++ .../Contexts/ExecuteOpArgs.cs | 6 +++++- src/TensorFlowNET.Core/Operations/math_ops.cs | 21 +++++++++++++++++++ .../ManagedAPI/LinalgTest.cs | 9 ++++++++ 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index 1fef9c9e..f2749abc 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -26,6 +26,9 @@ namespace Tensorflow { linalg_ops ops = new linalg_ops(); + public Tensor einsum(string equation, Tensors inputs, string name = null) + => math_ops.einsum(equation, inputs, name: name); + public Tensor eye(int num_rows, int num_columns = -1, Shape batch_shape = null, diff --git a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs index 8710ea5d..2e633760 100644 --- a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs +++ b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs @@ -11,7 +11,11 @@ namespace Tensorflow public Func GetGradientAttrs { get; set; } public object[] OpInputArgs { get; set; } public Dictionary OpAttrs { get; set; } - + + /// + /// + /// + /// For array: OpInputArgs = new object[]{ } [DebuggerStepThrough] public ExecuteOpArgs(params object[] inputArgs) { diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index df960ad4..5657fafa 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -230,6 +230,27 @@ namespace Tensorflow }); } + public static Tensor einsum(string equation, Tensors inputs, string name = null) + { + return tf_with(ops.name_scope(name, "einsum", inputs), scope => + { + name = scope; + return tf.Context.ExecuteOp("Einsum", name, new ExecuteOpArgs + { + OpInputArgs = new object[] { inputs.ToArray() }, + GetGradientAttrs = (op) => new + { + equation = op.get_attr("equation"), + N = op.get_attr("N"), + T = op.get_attr("T") + } + }.SetAttributes(new + { + equation = equation + })); + }); + } + public static Tensor greater_equal(Tx x, Ty y, string name = null) => gen_math_ops.greater_equal(x, y, name: name); public static Tensor equal(Tx x, Ty y, string name = null) diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs index a953cce8..eefc1c47 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs @@ -45,5 +45,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreEqual(x_under_reg.shape, (4, 1)); AssetSequenceEqual(x_under_reg.ToArray(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/ } + + [TestMethod] + public void Einsum() + { + var m0 = tf.random.normal((2, 3)); + var m1 = tf.random.normal((3, 5)); + var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); + Assert.AreEqual(e.shape, (2, 5)); + } } }