You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

PoolingTest.cs 9.6 kB


  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System.Linq;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. using Microsoft.VisualBasic;
  8. namespace TensorFlowNET.Keras.UnitTest
  9. {
  /// <summary>
  /// Unit tests for the Keras pooling layers (global average, global max and max pooling).
  /// Layer reference: https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
  /// </summary>
  13. [TestClass]
  14. public class PoolingTest : EagerModeTestBase
  15. {
  /// <summary>
  /// Shared 4 x 3 x 5 float fixture fed to the 1D pooling tests.
  /// Inside each of the four outer entries the three rows are identical.
  /// </summary>
  private NDArray input_array_1D = np.array(new float[,,]
  {
      { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f } },
      { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } },
      { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
      { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
  });
  /// <summary>
  /// Shared 4 x 2 x 3 x 5 float fixture fed to the 2D global-average pooling tests.
  /// Entries 0 and 2 hold two different row patterns; entries 1 and 3 repeat a single pattern.
  /// </summary>
  private NDArray input_array_2D = np.array(new float[,,,]
  {
      {
          { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f } },
          { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
      },
      {
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
      },
      {
          { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f } },
          { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
      },
      {
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
      }
  });
  /// <summary>
  /// Runs GlobalAveragePooling1D (no explicit data_format) over the shared 1D fixture
  /// and expects a 4 x 5 result.
  /// </summary>
  [TestMethod]
  public void GlobalAverage1DPoolingChannelsLast()
  {
      var layer = keras.layers.GlobalAveragePooling1D();
      var output = layer.Apply(input_array_1D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(5, output.shape[1]);

      // The three rows of every fixture entry are identical, so the expected
      // average of an entry is simply that row.
      var want = np.array(new float[,]
      {
          { 1f, 2f, 3f, 3f, 3f },
          { 4f, 5f, 6f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalAveragePooling1D with data_format "channels_first" over the shared
  /// 1D fixture and expects a 4 x 3 result.
  /// </summary>
  [TestMethod]
  public void GlobalAverage1DPoolingChannelsFirst()
  {
      var layer = keras.layers.GlobalAveragePooling1D(data_format: "channels_first");
      var output = layer.Apply(input_array_1D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(3, output.shape[1]);

      // Each fixture row of five values averages to 12/5, 21/5 or 30/5.
      var want = np.array(new float[,]
      {
          { 2.4f, 2.4f, 2.4f },
          { 4.2f, 4.2f, 4.2f },
          { 6f, 6f, 6f },
          { 6f, 6f, 6f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalAveragePooling2D (no explicit data_format) over the shared 2D fixture
  /// and expects a 4 x 5 result.
  /// </summary>
  [TestMethod]
  public void GlobalAverage2DPoolingChannelsLast()
  {
      var layer = keras.layers.GlobalAveragePooling2D();
      var output = layer.Apply(input_array_2D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(5, output.shape[1]);

      // Entries 0 and 2 mix the {1,2,3,3,3} and {4,5,6,3,3} rows in equal numbers,
      // hence the .5 values; entries 1 and 3 contain a single repeated row.
      var want = np.array(new float[,]
      {
          { 2.5f, 3.5f, 4.5f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f },
          { 2.5f, 3.5f, 4.5f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalAveragePooling2D with data_format "channels_first" over the shared
  /// 2D fixture and expects a 4 x 2 result.
  /// </summary>
  [TestMethod]
  public void GlobalAverage2DPoolingChannelsFirst()
  {
      var layer = keras.layers.GlobalAveragePooling2D(data_format: "channels_first");
      var output = layer.Apply(input_array_2D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(2, output.shape[1]);

      // One value per 3 x 5 slab of the fixture: 12/5, 21/5 or 30/5.
      var want = np.array(new float[,]
      {
          { 2.4f, 4.2f },
          { 6f, 6f },
          { 2.4f, 4.2f },
          { 6f, 6f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalMaxPooling1D (no explicit data_format) over the shared 1D fixture
  /// and expects a 4 x 5 result.
  /// </summary>
  [TestMethod]
  public void GlobalMax1DPoolingChannelsLast()
  {
      var layer = keras.layers.GlobalMaxPooling1D();
      var output = layer.Apply(input_array_1D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(5, output.shape[1]);

      // The three rows of every fixture entry are identical, so the expected
      // maximum of an entry is simply that row.
      var want = np.array(new float[,]
      {
          { 1f, 2f, 3f, 3f, 3f },
          { 4f, 5f, 6f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f },
          { 7f, 8f, 9f, 3f, 3f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalMaxPooling1D with data_format "channels_first" over the shared
  /// 1D fixture and expects a 4 x 3 result.
  /// </summary>
  [TestMethod]
  public void GlobalMax1DPoolingChannelsFirst()
  {
      var layer = keras.layers.GlobalMaxPooling1D(data_format: "channels_first");
      var output = layer.Apply(input_array_1D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(3, output.shape[1]);

      // The largest value inside each fixture row of five is 3, 6 or 9.
      var want = np.array(new float[,]
      {
          { 3f, 3f, 3f },
          { 6f, 6f, 6f },
          { 9f, 9f, 9f },
          { 9f, 9f, 9f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalMaxPooling2D (no explicit data_format) over a local 4 x 2 x 3 x 5 fixture
  /// and expects a 4 x 5 result.
  /// </summary>
  [TestMethod]
  public void GlobalMax2DPoolingChannelsLast()
  {
      // Local fixture with a 9 planted at [0,0,2,3] and at [2,0,2,4] so that the
      // expected maxima of entries 0 and 2 differ from each other.
      var input = np.array(new float[,,,]
      {
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 9f, 3f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          },
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 9f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          }
      });

      var layer = keras.layers.GlobalMaxPooling2D();
      var output = layer.Apply(input);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(5, output.shape[1]);

      // Column-wise maxima per entry; the planted 9s surface in column 3 (entry 0)
      // and column 4 (entry 2).
      var want = np.array(new float[,]
      {
          { 4f, 5f, 6f, 9f, 3f },
          { 7f, 8f, 9f, 3f, 3f },
          { 4f, 5f, 6f, 3f, 9f },
          { 7f, 8f, 9f, 3f, 3f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs GlobalMaxPooling2D with data_format "channels_first" over a local
  /// 4 x 2 x 3 x 5 fixture and expects a 4 x 2 result.
  /// </summary>
  [TestMethod]
  public void GlobalMax2DPoolingChannelsFirst()
  {
      // Local fixture with a 9 planted at [0,0,2,3] and at [2,0,2,4].
      var input = np.array(new float[,,,]
      {
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 9f, 3f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          },
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 9f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          }
      });

      var layer = keras.layers.GlobalMaxPooling2D(data_format: "channels_first");
      var output = layer.Apply(input);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(2, output.shape[1]);

      // One maximum per 3 x 5 slab; the planted 9s lift the first slab of
      // entries 0 and 2 from 3 to 9.
      var want = np.array(new float[,]
      {
          { 9f, 6f },
          { 9f, 9f },
          { 9f, 6f },
          { 9f, 9f }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs MaxPooling1D with pool_size 2 and strides 1 over the shared 1D fixture
  /// and expects a 4 x 2 x 5 result.
  /// </summary>
  [TestMethod]
  public void Max1DPoolingChannelsLast()
  {
      var layer = keras.layers.MaxPooling1D(pool_size: 2, strides: 1);
      var output = layer.Apply(input_array_1D);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(2, output.shape[1]);
      Assert.AreEqual(5, output.shape[2]);

      // The rows inside each fixture entry are identical, so both expected rows
      // of an entry equal that row.
      var want = np.array(new float[,,]
      {
          { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f } },
          { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } },
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
          { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  /// <summary>
  /// Runs MaxPooling2D with pool_size 2 and strides 1 over a local 4 x 2 x 3 x 5 fixture
  /// and expects a 4 x 1 x 2 x 5 result.
  /// </summary>
  [TestMethod]
  public void Max2DPoolingChannelsLast()
  {
      // Local fixture with a 9 planted at [0,0,2,3] and at [2,0,2,4].
      var input = np.array(new float[,,,]
      {
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 9f, 3f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          },
          {
              { { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 3f }, { 1f, 2f, 3f, 3f, 9f } },
              { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 3f } }
          },
          {
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } },
              { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } }
          }
      });

      var layer = keras.layers.MaxPooling2D(pool_size: 2, strides: 1);
      var output = layer.Apply(input);

      Assert.AreEqual(4, output.shape[0]);
      Assert.AreEqual(1, output.shape[1]);
      Assert.AreEqual(2, output.shape[2]);
      Assert.AreEqual(5, output.shape[3]);

      // Only the second expected row of entries 0 and 2 contains the planted 9,
      // because the 9 sits in the last row of the first group.
      var want = np.array(new float[,,,]
      {
          { { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 9f, 3f } } },
          { { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } } },
          { { { 4f, 5f, 6f, 3f, 3f }, { 4f, 5f, 6f, 3f, 9f } } },
          { { { 7f, 8f, 9f, 3f, 3f }, { 7f, 8f, 9f, 3f, 3f } } }
      });

      // NOTE(review): output[0] is compared with the full expected array, so the indexer
      // presumably yields the first output tensor rather than a batch row - confirm.
      var got = output[0].numpy();
      Assert.AreEqual(want, got);
  }
  251. }
  252. }