You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flatten.cs 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. using System;
  2. using FluentAssertions;
  3. using Microsoft.VisualStudio.TestTools.UnitTesting;
  4. using NumSharp;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.layers_test
  {
      /// <summary>
      /// Shape tests for <c>tf.layers.flatten</c>. Each case builds an int32 placeholder,
      /// runs the flatten layer on it in a graph session with a fed NDArray, and checks
      /// the shape of the result (or that an error is raised).
      /// </summary>
      [TestClass]
      public class flatten
      {
          /// <summary>
          /// Fully specified 5-D input (3, 4, 3, 1, 2): the leading axis is kept and the
          /// remaining axes collapse to 4 * 3 * 1 * 2 = 24, giving (3, 24).
          /// </summary>
          [TestMethod]
          public void Case1()
          {
              // Sessions are disposable; release each one when the test finishes instead
              // of leaking it for the rest of the test run.
              using (var sess = tf.Session().as_default())
              {
                  var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
                  var feed = np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2);

                  sess.run(tf.layers.flatten(input), (input, feed)).Should().BeShaped(3, 24);
              }
          }

          /// <summary>
          /// Rank-1 input of length 6: the expected result is (6, 1), i.e. a trailing
          /// unit axis is added.
          /// </summary>
          [TestMethod]
          public void Case2()
          {
              using (var sess = tf.Session().as_default())
              {
                  var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
                  var feed = np.arange(6);

                  sess.run(tf.layers.flatten(input), (input, feed)).Should().BeShaped(6, 1);
              }
          }

          /// <summary>
          /// Scalar (rank-0) input: flattening is expected to fail with <see cref="ValueError"/>.
          /// </summary>
          [TestMethod]
          public void Case3()
          {
              using (var sess = tf.Session().as_default())
              {
                  var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());

                  // Both building the layer and running it happen inside the lambda, so the
                  // assertion holds whichever of the two steps raises the error.
                  new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6))))
                      .Should().Throw<ValueError>();
              }
          }

          /// <summary>
          /// 5-D input whose third dimension is left unspecified in the placeholder;
          /// feeding a concrete (3, 4, 3, 1, 2) array must still produce (3, 24).
          /// </summary>
          [TestMethod]
          public void Case4()
          {
              using (var sess = tf.Session().as_default())
              {
                  var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
                  var feed = np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2);

                  sess.run(tf.layers.flatten(input), (input, feed)).Should().BeShaped(3, 24);
              }
          }

          /// <summary>
          /// 5-D input whose leading (batch) dimension is left unspecified in the placeholder;
          /// feeding a concrete (3, 4, 3, 1, 2) array must still produce (3, 24).
          /// </summary>
          [TestMethod]
          public void Case5()
          {
              using (var sess = tf.Session().as_default())
              {
                  var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
                  var feed = np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2);

                  sess.run(tf.layers.flatten(input), (input, feed)).Should().BeShaped(3, 24);
              }
          }
      }
  }