You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flatten.cs 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. using System;
  2. using FluentAssertions;
  3. using Microsoft.VisualStudio.TestTools.UnitTesting;
  4. using NumSharp;
  5. using Tensorflow;
  6. using Tensorflow.UnitTest;
  7. using static Tensorflow.Binding;
  8. namespace TensorFlowNET.UnitTest.layers_test
  9. {
  10. [TestClass]
  11. public class flatten : GraphModeTestBase
  12. {
  13. [TestMethod]
  14. public void Case1()
  15. {
  16. var sess = tf.Session().as_default();
  17. var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
  18. sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
  19. }
  20. [TestMethod]
  21. public void Case2()
  22. {
  23. var sess = tf.Session().as_default();
  24. var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
  25. sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1);
  26. }
  27. [TestMethod]
  28. public void Case3()
  29. {
  30. var sess = tf.Session().as_default();
  31. var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
  32. new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
  33. }
  34. [TestMethod]
  35. public void Case4()
  36. {
  37. var sess = tf.Session().as_default();
  38. var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
  39. sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
  40. }
  41. [TestMethod]
  42. public void Case5()
  43. {
  44. var sess = tf.Session().as_default();
  45. var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
  46. sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
  47. }
  48. }
  49. }