@@ -26,6 +26,9 @@ namespace Tensorflow | |||||
{ | { | ||||
linalg_ops ops = new linalg_ops(); | linalg_ops ops = new linalg_ops(); | ||||
public Tensor einsum(string equation, Tensors inputs, string name = null) | |||||
=> math_ops.einsum(equation, inputs, name: name); | |||||
public Tensor eye(int num_rows, | public Tensor eye(int num_rows, | ||||
int num_columns = -1, | int num_columns = -1, | ||||
Shape batch_shape = null, | Shape batch_shape = null, | ||||
@@ -11,7 +11,11 @@ namespace Tensorflow | |||||
public Func<Operation, object> GetGradientAttrs { get; set; } | public Func<Operation, object> GetGradientAttrs { get; set; } | ||||
public object[] OpInputArgs { get; set; } | public object[] OpInputArgs { get; set; } | ||||
public Dictionary<string, object> OpAttrs { get; set; } | public Dictionary<string, object> OpAttrs { get; set; } | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="inputArgs">For array: OpInputArgs = new object[]{ }</param> | |||||
[DebuggerStepThrough] | [DebuggerStepThrough] | ||||
public ExecuteOpArgs(params object[] inputArgs) | public ExecuteOpArgs(params object[] inputArgs) | ||||
{ | { | ||||
@@ -230,6 +230,27 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor einsum(string equation, Tensors inputs, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "einsum", inputs), scope => | |||||
{ | |||||
name = scope; | |||||
return tf.Context.ExecuteOp("Einsum", name, new ExecuteOpArgs | |||||
{ | |||||
OpInputArgs = new object[] { inputs.ToArray() }, | |||||
GetGradientAttrs = (op) => new | |||||
{ | |||||
equation = op.get_attr<string>("equation"), | |||||
N = op.get_attr<int>("N"), | |||||
T = op.get_attr<TF_DataType>("T") | |||||
} | |||||
}.SetAttributes(new | |||||
{ | |||||
equation = equation | |||||
})); | |||||
}); | |||||
} | |||||
public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
=> gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name); | => gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name); | ||||
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
@@ -45,5 +45,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
Assert.AreEqual(x_under_reg.shape, (4, 1)); | Assert.AreEqual(x_under_reg.shape, (4, 1)); | ||||
AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/ | AssetSequenceEqual(x_under_reg.ToArray<float>(), new float[] { -0.04763567f, -1.214508f, 0.62748903f, 1.299031f });*/ | ||||
} | } | ||||
[TestMethod] | |||||
public void Einsum() | |||||
{ | |||||
var m0 = tf.random.normal((2, 3)); | |||||
var m1 = tf.random.normal((3, 5)); | |||||
var e = tf.linalg.einsum("ij,jk->ik", (m0, m1)); | |||||
Assert.AreEqual(e.shape, (2, 5)); | |||||
} | |||||
} | } | ||||
} | } |