diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 35284eb2..1bf03e7f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -485,6 +485,16 @@ namespace Tensorflow public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Split", name, + null, + axis, value, num_split); + + return results; + } + var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); return _op.outputs; } diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs index 06448e5e..d2bd19ca 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs @@ -52,5 +52,19 @@ namespace Tensorflow.UnitTest.TF_API var concatValue = tf.concat(new[] { a, b, c }, axis: 0); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); } + [TestMethod] + public void SplitTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + + var splitValue = tf.split(concatValue, 3, axis: new Tensor(0)); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); + + } + } }