Browse Source

add tf.split problem describe

tags/v0.20
pepure Haiping 5 years ago
parent
commit
47ba2ae6fd
2 changed files with 24 additions and 0 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  2. +14
    -0
      test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs

+ 10
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -485,6 +485,16 @@ namespace Tensorflow

public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Split", name,
null,
axis, value, num_split);

return results;
}

var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}


+ 14
- 0
test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs View File

@@ -52,5 +52,19 @@ namespace Tensorflow.UnitTest.TF_API
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
}
[TestMethod]
public void SplitTest()
{
var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });
var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });

var concatValue = tf.concat(new[] { a, b, c }, axis: 0);

var splitValue = tf.split(concatValue, 3, axis: new Tensor(0));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape));

}

}
}

Loading…
Cancel
Save