You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concat_op_test.cs 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.NumPy;
  4. using TensorFlow;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. namespace TensorFlow.Kernel.UnitTest
  8. {
  9. [TestClass]
  10. public class concat_op_test
  11. {
  12. [TestMethod]
  13. public void testConcatEmpty()
  14. {
  15. var t1 = tf.constant(new int[] { });
  16. var t2 = tf.constant(new int[] { });
  17. var c = array_ops.concat(new[] { t1, t2 }, 0);
  18. var expected = np.array(new int[] { });
  19. Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>()));
  20. }
  21. [TestMethod]
  22. public void testConcatNegativeAxis()
  23. {
  24. var t1 = tf.constant(new int[,] {{ 1, 2, 3 }, { 4, 5, 6 } });
  25. var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });
  26. var c = array_ops.concat(new[] { t1, t2 }, -2);
  27. var expected = np.array(new int[,,] { { { 1, 2, 3 }, { 4, 5, 6 } }, { { 7, 8, 9 }, { 10, 11, 12 } } });
  28. Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>()));
  29. c = array_ops.concat(new[] { t1, t2 }, -1);
  30. expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
  31. Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>()));
  32. }
  33. [TestMethod]
  34. [DataRow(TF_DataType.TF_INT32)]
  35. [DataRow(TF_DataType.TF_INT64)]
  36. [DataRow(TF_DataType.TF_UINT32)]
  37. [DataRow(TF_DataType.TF_UINT64)]
  38. public void testConcatDtype(TF_DataType dtype)
  39. {
  40. var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }, dtype: dtype);
  41. var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }, dtype: dtype);
  42. var c = array_ops.concat(new[] { t1, t2 }, 1);
  43. var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
  44. Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray<int>()));
  45. }
  46. [TestMethod]
  47. [DataRow(TF_DataType.TF_INT32)]
  48. [DataRow(TF_DataType.TF_INT64)]
  49. public void testConcatAxisType(TF_DataType dtype)
  50. {
  51. var t1 = tf.constant(new int[,] { { 1, 2, 3 }, {4, 5, 6 } });
  52. var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });
  53. var c = array_ops.concat(new[] { t1, t2 }, tf.constant(1, dtype: dtype));
  54. var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
  55. Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray<int>()));
  56. }
  57. }
  58. }