You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

PoolingTest.cs 9.7 kB


  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  namespace TensorFlowNET.Keras.UnitTest
  {
      /// <summary>
      /// Unit tests for the Keras pooling layers: global average / global max pooling (1D and 2D)
      /// and windowed max pooling (1D and 2D), in both "channels_last" and "channels_first" formats.
      /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
      /// </summary>
      [TestClass]
      public class PoolingTest : EagerModeTestBase
      {
          /// <summary>
          /// Rank-3 fixture of shape (4, 3, 5). Inside each batch entry the three rows are identical,
          /// so reducing over axis 1 yields that row, and reducing over axis 2 reduces each row's 5 values.
          /// </summary>
          private readonly NDArray input_array_1D = np.array(new float[,,]
          {
              { { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 } },
              { { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 } },
              { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } },
              { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } }
          });

          /// <summary>
          /// Rank-4 fixture of shape (4, 2, 3, 5). Batches 0 and 2 are identical, as are batches 1 and 3.
          /// </summary>
          private readonly NDArray input_array_2D = np.array(new float[,,,]
          {
              {
                  { { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 } },
                  { { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 } }
              },
              {
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } },
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } }
              },
              {
                  { { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 } },
                  { { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 } }
              },
              {
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } },
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } }
              }
          });

          /// <summary>
          /// Same shape and values as <see cref="input_array_2D"/>, except for two planted 9s so that
          /// max pooling can be told apart from the plain data:
          /// batch 0 at [0, 2, 3] and batch 2 at [0, 2, 4].
          /// Shared by the max-pooling 2D tests (previously duplicated inline in each of them).
          /// </summary>
          private readonly NDArray input_array_2D_peaks = np.array(new float[,,,]
          {
              {
                  { { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 }, { 1, 2, 3, 9, 3 } },
                  { { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 } }
              },
              {
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } },
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } }
              },
              {
                  { { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 3 }, { 1, 2, 3, 3, 9 } },
                  { { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 }, { 4, 5, 6, 3, 3 } }
              },
              {
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } },
                  { { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 }, { 7, 8, 9, 3, 3 } }
              }
          });

          /// <summary>
          /// Averages (4, 3, 5) over its 3 rows; the rows are identical, so each output row is the input row.
          /// </summary>
          [TestMethod]
          public void GlobalAverage1DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalAveragePooling1D();
              var y = pool.Apply(input_array_1D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 1, 2, 3, 3, 3 },
                  { 4, 5, 6, 3, 3 },
                  { 7, 8, 9, 3, 3 },
                  { 7, 8, 9, 3, 3 }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Averages each row of 5 values: 12/5 = 2.4, 21/5 = 4.2, 30/5 = 6.0.
          /// </summary>
          [TestMethod]
          public void GlobalAverage1DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalAveragePooling1D(data_format: "channels_first");
              var y = pool.Apply(input_array_1D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(3, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 2.4f, 2.4f, 2.4f },
                  { 4.2f, 4.2f, 4.2f },
                  { 6.0f, 6.0f, 6.0f },
                  { 6.0f, 6.0f, 6.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Averages (4, 2, 3, 5) over axes 1 and 2, e.g. batch 0 column 0: (3 * 1 + 3 * 4) / 6 = 2.5.
          /// </summary>
          [TestMethod]
          public void GlobalAverage2DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalAveragePooling2D();
              var y = pool.Apply(input_array_2D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 2.5f, 3.5f, 4.5f, 3.0f, 3.0f },
                  { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                  { 2.5f, 3.5f, 4.5f, 3.0f, 3.0f },
                  { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Treats axis 1 (size 2) as channels and averages each 3x5 plane: 2.4, 4.2 or 6.0.
          /// </summary>
          [TestMethod]
          public void GlobalAverage2DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalAveragePooling2D(data_format: "channels_first");
              var y = pool.Apply(input_array_2D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 2.4f, 4.2f },
                  { 6.0f, 6.0f },
                  { 2.4f, 4.2f },
                  { 6.0f, 6.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Max over the 3 identical rows of each batch entry returns the row itself.
          /// </summary>
          [TestMethod]
          public void GlobalMax1DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalMaxPooling1D();
              var y = pool.Apply(input_array_1D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 1, 2, 3, 3, 3 },
                  { 4, 5, 6, 3, 3 },
                  { 7, 8, 9, 3, 3 },
                  { 7, 8, 9, 3, 3 }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Max of each row of 5 values: 3, 6 or 9.
          /// </summary>
          [TestMethod]
          public void GlobalMax1DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalMaxPooling1D(data_format: "channels_first");
              var y = pool.Apply(input_array_1D);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(3, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 3.0f, 3.0f, 3.0f },
                  { 6.0f, 6.0f, 6.0f },
                  { 9.0f, 9.0f, 9.0f },
                  { 9.0f, 9.0f, 9.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Max over axes 1 and 2 per channel; the planted 9s surface in channel 3 (batch 0)
          /// and channel 4 (batch 2).
          /// </summary>
          [TestMethod]
          public void GlobalMax2DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalMaxPooling2D();
              var y = pool.Apply(input_array_2D_peaks);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 4.0f, 5.0f, 6.0f, 9.0f, 3.0f },
                  { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                  { 4.0f, 5.0f, 6.0f, 3.0f, 9.0f },
                  { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// Max of each 3x5 plane; the planted 9s lift channel 0 of batches 0 and 2 to 9.
          /// </summary>
          [TestMethod]
          public void GlobalMax2DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalMaxPooling2D(data_format: "channels_first");
              var y = pool.Apply(input_array_2D_peaks);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  { 9.0f, 6.0f },
                  { 9.0f, 9.0f },
                  { 9.0f, 6.0f },
                  { 9.0f, 9.0f }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// pool_size 2 / stride 1 over 3 steps gives 2 windows (the asserted output shape is (4, 2, 5)).
          /// Every window spans two identical rows, so each pooled row equals the input row.
          /// </summary>
          [TestMethod, Ignore("There's an error generated from TF complaining about the shape of the pool. Needs further investigation.")]
          public void Max1DPoolingChannelsLast()
          {
              var x = input_array_1D;
              var pool = keras.layers.MaxPooling1D(pool_size: 2, strides: 1);
              var y = pool.Apply(x);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              Assert.AreEqual(5, y.shape[2]);
              var expected = np.array(new float[,,]
              {
                  // max({1,2,3,3,3}, {1,2,3,3,3}) is {1,2,3,3,3} for both windows of batch 0.
                  { { 1.0f, 2.0f, 3.0f, 3.0f, 3.0f },
                    { 1.0f, 2.0f, 3.0f, 3.0f, 3.0f } },
                  { { 4.0f, 5.0f, 6.0f, 3.0f, 3.0f },
                    { 4.0f, 5.0f, 6.0f, 3.0f, 3.0f } },
                  { { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                    { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f } },
                  { { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                    { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f } }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          /// <summary>
          /// 2x2 windows with stride 1 over a 2x3 grid give a 1x2 output per batch entry.
          /// Only the second window (columns 1-2) contains a planted 9.
          /// </summary>
          [TestMethod]
          public void Max2DPoolingChannelsLast()
          {
              var pool = keras.layers.MaxPooling2D(pool_size: 2, strides: 1);
              var y = pool.Apply(input_array_2D_peaks);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(1, y.shape[1]);
              Assert.AreEqual(2, y.shape[2]);
              Assert.AreEqual(5, y.shape[3]);
              var expected = np.array(new float[,,,]
              {
                  { { { 4.0f, 5.0f, 6.0f, 3.0f, 3.0f },
                      { 4.0f, 5.0f, 6.0f, 9.0f, 3.0f } } },
                  { { { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                      { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f } } },
                  { { { 4.0f, 5.0f, 6.0f, 3.0f, 3.0f },
                      { 4.0f, 5.0f, 6.0f, 3.0f, 9.0f } } },
                  { { { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f },
                      { 7.0f, 8.0f, 9.0f, 3.0f, 3.0f } } }
              });
              Assert.AreEqual(expected, y[0].numpy());
          }
      }
  }