You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concat_op_test.cs 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.NumPy;
  4. using static Tensorflow.Binding;
  namespace TensorFlow.Kernel.UnitTest
  {
      /// <summary>
      /// Unit tests for <c>array_ops.concat</c>: empty inputs, negative axes,
      /// several element dtypes, and tensor-valued axes of int32/int64 dtype.
      /// Every test checks both the result shape and the row-major element values,
      /// so a result with the right values but the wrong rank or axis layout fails.
      /// </summary>
      [TestClass]
      public class concat_op_test
      {
          /// <summary>
          /// Concatenating two empty 1-D tensors along axis 0 yields an empty 1-D tensor.
          /// </summary>
          [TestMethod]
          public void testConcatEmpty()
          {
              var t1 = tf.constant(new int[] { });
              var t2 = tf.constant(new int[] { });

              var c = array_ops.concat(new[] { t1, t2 }, 0);

              var expected = np.array(new int[] { });
              AssertConcatResult(expected, c);
          }

          /// <summary>
          /// Negative axes count from the last dimension. For rank-2 inputs, axis -2 is
          /// axis 0 (stack rows) and axis -1 is axis 1 (join columns).
          /// </summary>
          [TestMethod]
          public void testConcatNegativeAxis()
          {
              var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });
              var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });

              // axis -2 == axis 0: two (2,3) inputs become one (4,3) result, still rank 2.
              // (The previous expectation was a rank-3 (2,2,3) array, which only passed
              // because shapes were never compared.)
              var c = array_ops.concat(new[] { t1, t2 }, -2);
              var expected = np.array(new int[,] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }, { 10, 11, 12 } });
              AssertConcatResult(expected, c);

              // axis -1 == axis 1: two (2,3) inputs become one (2,6) result.
              c = array_ops.concat(new[] { t1, t2 }, -1);
              expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
              AssertConcatResult(expected, c);
          }

          /// <summary>
          /// Concatenation works for several integer element dtypes and keeps the input dtype.
          /// </summary>
          /// <param name="dtype">Element dtype of both input tensors.</param>
          [TestMethod]
          [DataRow(TF_DataType.TF_INT32)]
          [DataRow(TF_DataType.TF_INT64)]
          [DataRow(TF_DataType.TF_UINT32)]
          [DataRow(TF_DataType.TF_UINT64)]
          public void testConcatDtype(TF_DataType dtype)
          {
              var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }, dtype: dtype);
              var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }, dtype: dtype);

              var c = array_ops.concat(new[] { t1, t2 }, 1);

              // The values are compared after a cast to int32, which would hide a dtype
              // change, so the dtype is checked explicitly first.
              Assert.AreEqual(dtype, c.dtype, "concat changed the element dtype");
              var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
              AssertConcatResult(expected, tf.cast(c, TF_DataType.TF_INT32));
          }

          /// <summary>
          /// The concat axis may be passed as a scalar tensor of either int32 or int64 dtype.
          /// </summary>
          /// <param name="dtype">Dtype of the scalar axis tensor.</param>
          [TestMethod]
          [DataRow(TF_DataType.TF_INT32)]
          [DataRow(TF_DataType.TF_INT64)]
          public void testConcatAxisType(TF_DataType dtype)
          {
              var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });
              var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });

              var c = array_ops.concat(new[] { t1, t2 }, tf.constant(1, dtype: dtype));

              var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
              AssertConcatResult(expected, tf.cast(c, TF_DataType.TF_INT32));
          }

          /// <summary>
          /// Asserts that <paramref name="actual"/> has the same shape as <paramref name="expected"/>
          /// and the same int32 elements in row-major order.
          /// </summary>
          /// <param name="expected">Reference array holding the expected shape and values.</param>
          /// <param name="actual">Concat result to check. Its elements are read as int32.</param>
          private static void AssertConcatResult(NDArray expected, Tensor actual)
          {
              // Comparing flattened values alone cannot tell a (4,3) result from a (2,6)
              // or (2,2,3) one, so the shape is checked first.
              CollectionAssert.AreEqual(expected.shape.dims, actual.shape.dims, "shape mismatch");
              CollectionAssert.AreEqual(expected.ToArray<int>(), actual.numpy().ToArray<int>(), "value mismatch");
          }
      }
  }