You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

PoolingTest.cs 9.5 kB


  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Layers
  {
      /// <summary>
      /// Tests for the Keras pooling layers: global average / global max pooling in 1D and 2D,
      /// and windowed max pooling in 1D and 2D, using both the default data format and
      /// data_format: "channels_first".
      /// See https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
      /// </summary>
      /// <remarks>
      /// NOTE(review): <c>y[0]</c> below is compared against an expected array that spans the whole
      /// batch, so it is presumably the layer's single output tensor (Apply returns a Tensors
      /// collection), not batch row 0 — confirm against the Tensors API.
      /// </remarks>
      [TestClass]
      public class PoolingTest : EagerModeTestBase
      {
          // Rank-3 input of shape (4, 3, 5). Within each batch entry the three rows are identical,
          // so reducing axis 1 reproduces that row, while reducing axis 2 collapses each row
          // {a, b, c, 3, 3} to a single value.
          private readonly NDArray input_array_1D = np.array(new float[,,]
          {
              {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
              {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
          });

          // Rank-4 input of shape (4, 2, 3, 5). Batch entries 0 and 2 are identical, as are 1 and 3.
          private readonly NDArray input_array_2D = np.array(new float[,,,]
          {{
              {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
              {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
          },{
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
          },{
              {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
              {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
          },{
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
              {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
          }});

          /// <summary>
          /// Builds a (4, 2, 3, 5) input equal to <c>input_array_2D</c> except for two extra 9s, at
          /// [0, 0, 2, 3] and [2, 0, 2, 4]. The 9s only surface in max-pooling results, so they show
          /// which axes were pooled and which output position a maximum lands in.
          /// </summary>
          /// <returns>A freshly created array on every call, so tests never share an instance.</returns>
          private static NDArray CreateInputArray2DWithPeaks()
          {
              return np.array(new float[,,,]
              {{
                  {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}},
                  {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
              },{
                  {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
                  {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
              },{
                  {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}},
                  {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
              },{
                  {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
                  {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
              }});
          }

          [TestMethod]
          public void GlobalAverage1DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalAveragePooling1D();
              var y = pool.Apply(input_array_1D);

              // Axis 1 (length 3) is averaged away, leaving shape (4, 5).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {1,2,3,3,3},
                  {4,5,6,3,3},
                  {7,8,9,3,3},
                  {7,8,9,3,3}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalAverage1DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalAveragePooling1D(data_format: "channels_first");
              var y = pool.Apply(input_array_1D);

              // Axis 2 (length 5) is averaged away, leaving shape (4, 3):
              // {1,2,3,3,3} -> 12/5 = 2.4, {4,5,6,3,3} -> 21/5 = 4.2, {7,8,9,3,3} -> 30/5 = 6.0.
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(3, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {2.4f, 2.4f, 2.4f},
                  {4.2f, 4.2f, 4.2f},
                  {6.0f, 6.0f, 6.0f},
                  {6.0f, 6.0f, 6.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalAverage2DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalAveragePooling2D();
              var y = pool.Apply(input_array_2D);

              // Axes 1 and 2 (2 x 3 = 6 values per channel) are averaged away, leaving shape (4, 5).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {2.5f, 3.5f, 4.5f, 3.0f, 3.0f},
                  {7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                  {2.5f, 3.5f, 4.5f, 3.0f, 3.0f},
                  {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalAverage2DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalAveragePooling2D(data_format: "channels_first");
              var y = pool.Apply(input_array_2D);

              // Axes 2 and 3 (3 x 5 = 15 values per channel) are averaged away, leaving shape (4, 2).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {2.4f, 4.2f},
                  {6.0f, 6.0f},
                  {2.4f, 4.2f},
                  {6.0f, 6.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalMax1DPoolingChannelsLast()
          {
              var pool = keras.layers.GlobalMaxPooling1D();
              var y = pool.Apply(input_array_1D);

              // Axis 1 (length 3) is max-reduced, leaving shape (4, 5).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {1,2,3,3,3},
                  {4,5,6,3,3},
                  {7,8,9,3,3},
                  {7,8,9,3,3}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalMax1DPoolingChannelsFirst()
          {
              var pool = keras.layers.GlobalMaxPooling1D(data_format: "channels_first");
              var y = pool.Apply(input_array_1D);

              // Axis 2 (length 5) is max-reduced, leaving shape (4, 3).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(3, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {3.0f, 3.0f, 3.0f},
                  {6.0f, 6.0f, 6.0f},
                  {9.0f, 9.0f, 9.0f},
                  {9.0f, 9.0f, 9.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalMax2DPoolingChannelsLast()
          {
              var x = CreateInputArray2DWithPeaks();
              var pool = keras.layers.GlobalMaxPooling2D();
              var y = pool.Apply(x);

              // Axes 1 and 2 are max-reduced, leaving shape (4, 5); the injected 9s land in
              // channel 3 (batch 0) and channel 4 (batch 2).
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(5, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {4.0f, 5.0f, 6.0f, 9.0f, 3.0f},
                  {7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                  {4.0f, 5.0f, 6.0f, 3.0f, 9.0f},
                  {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void GlobalMax2DPoolingChannelsFirst()
          {
              var x = CreateInputArray2DWithPeaks();
              var pool = keras.layers.GlobalMaxPooling2D(data_format: "channels_first");
              var y = pool.Apply(x);

              // Axes 2 and 3 are max-reduced, leaving shape (4, 2); the injected 9s lift
              // channel 0 of batches 0 and 2 from 3 to 9.
              Assert.AreEqual(2, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              var expected = np.array(new float[,]
              {
                  {9.0f, 6.0f},
                  {9.0f, 9.0f},
                  {9.0f, 6.0f},
                  {9.0f, 9.0f}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void Max1DPoolingChannelsLast()
          {
              var pool = keras.layers.MaxPooling1D(pool_size: 2, strides: 1);
              var y = pool.Apply(input_array_1D);

              // A size-2, stride-1 window over the length-3 axis yields 2 positions: shape (4, 2, 5).
              Assert.AreEqual(3, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(2, y.shape[1]);
              Assert.AreEqual(5, y.shape[2]);
              var expected = np.array(new float[,,]
              {
                  {{1.0f, 2.0f, 3.0f, 3.0f, 3.0f},
                   {1.0f, 2.0f, 3.0f, 3.0f, 3.0f}},
                  {{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
                   {4.0f, 5.0f, 6.0f, 3.0f, 3.0f}},
                  {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                   {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}},
                  {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                   {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }

          [TestMethod]
          public void Max2DPoolingChannelsLast()
          {
              var x = CreateInputArray2DWithPeaks();
              var pool = keras.layers.MaxPooling2D(pool_size: 2, strides: 1);
              var y = pool.Apply(x);

              // A 2x2, stride-1 window over the 2x3 spatial grid yields 1x2 positions: shape (4, 1, 2, 5).
              // The injected 9s sit in the last column, so they appear only in the second window.
              Assert.AreEqual(4, y.shape.ndim);
              Assert.AreEqual(4, y.shape[0]);
              Assert.AreEqual(1, y.shape[1]);
              Assert.AreEqual(2, y.shape[2]);
              Assert.AreEqual(5, y.shape[3]);
              var expected = np.array(new float[,,,]
              {
                  {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
                    {4.0f, 5.0f, 6.0f, 9.0f, 3.0f}}},
                  {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                    {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}},
                  {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
                    {4.0f, 5.0f, 6.0f, 3.0f, 9.0f}}},
                  {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
                    {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}}
              });
              Assert.AreEqual(expected, y[0].numpy());
          }
      }
  }