You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsTest.cs 11 kB


  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using static Tensorflow.Binding;
  4. namespace Tensorflow.Keras.UnitTest.Layers.Metrics;
  5. [TestClass]
  6. public class MetricsTest : EagerModeTestBase
  7. {
  /// <summary>
  /// Exercises <c>tf.keras.metrics.Accuracy</c> on four samples: an unweighted pass
  /// that is expected to yield 0.75, then, after resetting the metric, a pass with
  /// sample weights { 1, 1, 0, 0 } that is expected to yield 0.5.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy
  /// </summary>
  [TestMethod]
  public void Accuracy()
  {
      var labels = np.array(new int[,] { { 1 }, { 2 }, { 3 }, { 4 } });
      var predictions = np.array(new float[,] { { 0f }, { 2f }, { 3f }, { 4f } });
      var metric = tf.keras.metrics.Accuracy();

      // Unweighted pass.
      metric.update_state(labels, predictions);
      Assert.AreEqual(metric.result().numpy(), 0.75f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 1f, 1f, 0f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.5f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.BinaryAccuracy</c> on four samples: the unweighted
  /// result is expected to be 0.75, and the result after a reset with sample weights
  /// { 1, 0, 0, 1 } is expected to be 0.5.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
  /// </summary>
  [TestMethod]
  public void BinaryAccuracy()
  {
      var labels = np.array(new int[,] { { 1 }, { 1 }, { 0 }, { 0 } });
      var predictions = np.array(new float[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
      var metric = tf.keras.metrics.BinaryAccuracy();

      // Unweighted pass.
      metric.update_state(labels, predictions);
      Assert.AreEqual(metric.result().numpy(), 0.75f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 1f, 0f, 0f, 1f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.5f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.CategoricalAccuracy</c> with one-hot labels for two
  /// samples: the unweighted result is expected to be 0.5, and the result after a reset
  /// with sample weights { 0.7, 0.3 } is expected to be 0.3.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void CategoricalAccuracy()
  {
      var oneHotLabels = np.array(new int[,] { { 0, 0, 1 }, { 0, 1, 0 } });
      var scores = np.array(new float[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
      var metric = tf.keras.metrics.CategoricalAccuracy();

      // Unweighted pass.
      metric.update_state(oneHotLabels, scores);
      Assert.AreEqual(metric.result().numpy(), 0.5f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.7f, 0.3f });
      metric.update_state(oneHotLabels, scores, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.3f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.SparseCategoricalAccuracy</c> with integer class
  /// labels { 2, 1 }: the unweighted result is expected to be 0.5, and the result after
  /// a reset with sample weights { 0.7, 0.3 } is expected to be 0.3.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void SparseCategoricalAccuracy()
  {
      var classIds = np.array(new int[] { 2, 1 });
      var scores = np.array(new float[,] { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } });
      var metric = tf.keras.metrics.SparseCategoricalAccuracy();

      // Unweighted pass.
      metric.update_state(classIds, scores);
      Assert.AreEqual(metric.result().numpy(), 0.5f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.7f, 0.3f });
      metric.update_state(classIds, scores, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.3f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.CategoricalCrossentropy</c> with one-hot labels for
  /// two samples: the unweighted result is expected to be 1.1769392, and the result
  /// after a reset with sample weights { 0.3, 0.7 } is expected to be 1.6271976.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
  /// </summary>
  [TestMethod]
  public void CategoricalCrossentropy()
  {
      var oneHotLabels = np.array(new int[,] { { 0, 1, 0 }, { 0, 0, 1 } });
      var scores = np.array(new float[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
      var metric = tf.keras.metrics.CategoricalCrossentropy();

      // Unweighted pass.
      metric.update_state(oneHotLabels, scores);
      Assert.AreEqual(metric.result().numpy(), 1.1769392f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.3f, 0.7f });
      metric.update_state(oneHotLabels, scores, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 1.6271976f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.SparseCategoricalCrossentropy</c> with integer class
  /// labels { 1, 2 }; the unweighted result is expected to be 1.1769392.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy
  /// </summary>
  [TestMethod]
  public void SparseCategoricalCrossentropy()
  {
      var classIds = np.array(new int[] { 1, 2 });
      var scores = np.array(new float[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
      var metric = tf.keras.metrics.SparseCategoricalCrossentropy();

      metric.update_state(classIds, scores);
      Assert.AreEqual(metric.result().numpy(), 1.1769392f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.CosineSimilarity</c> constructed with <c>axis: 1</c>
  /// on two samples: the unweighted result is expected to be 0.49999997, and the result
  /// after a reset with sample weights { 0.3, 0.7 } is expected to be 0.6999999.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity
  /// </summary>
  [TestMethod]
  public void CosineSimilarity()
  {
      var targets = np.array(new int[,] { { 0, 1 }, { 1, 1 } });
      var predictions = np.array(new float[,] { { 1f, 0f }, { 1f, 1f } });
      var metric = tf.keras.metrics.CosineSimilarity(axis: 1);

      // Unweighted pass.
      metric.update_state(targets, predictions);
      Assert.AreEqual(metric.result().numpy(), 0.49999997f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.3f, 0.7f });
      metric.update_state(targets, predictions, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.6999999f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.F1Score</c> with three classes and a threshold of
  /// 0.5 on three samples; the result is expected to equal { 0.5, 0.8, 0.6666667 }.
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score
  /// </summary>
  [TestMethod]
  public void F1Score()
  {
      var labels = np.array(new int[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
      var scores = np.array(new float[,]
      {
          { 0.2f, 0.6f, 0.7f },
          { 0.2f, 0.6f, 0.6f },
          { 0.6f, 0.8f, 0f }
      });
      var metric = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f);

      metric.update_state(labels, scores);
      Assert.AreEqual(metric.result().numpy(), new float[] { 0.5f, 0.8f, 0.6666667f });
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.FBetaScore</c> with three classes, beta 2.0 and a
  /// threshold of 0.5 on three samples; the result is expected to equal
  /// { 0.3846154, 0.90909094, 0.8333334 }.
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore
  /// </summary>
  [TestMethod]
  public void FBetaScore()
  {
      var labels = np.array(new int[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
      var scores = np.array(new float[,]
      {
          { 0.2f, 0.6f, 0.7f },
          { 0.2f, 0.6f, 0.6f },
          { 0.6f, 0.8f, 0f }
      });
      var metric = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f);

      metric.update_state(labels, scores);
      Assert.AreEqual(metric.result().numpy(), new float[] { 0.3846154f, 0.90909094f, 0.8333334f });
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.HammingLoss</c> in two configurations: mode
  /// "multiclass" with threshold 0.6 (expected result 0.25) and mode "multilabel"
  /// with threshold 0.8 (expected result 0.16666667). Each configuration uses its own
  /// freshly constructed metric.
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss
  /// </summary>
  [TestMethod]
  public void HammingLoss()
  {
      // Multi-class case: four one-hot samples over four classes.
      var labels = np.array(new int[,]
      {
          { 1, 0, 0, 0 },
          { 0, 0, 1, 0 },
          { 0, 0, 0, 1 },
          { 0, 1, 0, 0 }
      });
      var scores = np.array(new float[,]
      {
          { 0.8f, 0.1f, 0.1f, 0.0f },
          { 0.2f, 0.0f, 0.8f, 0.0f },
          { 0.05f, 0.05f, 0.1f, 0.8f },
          { 1.0f, 0.0f, 0.0f, 0.0f }
      });
      var metric = tf.keras.metrics.HammingLoss(mode: "multiclass", threshold: 0.6f);
      metric.update_state(labels, scores);
      Assert.AreEqual(metric.result().numpy(), 0.25f);

      // Multi-label case: three samples, each of which may carry several labels.
      labels = np.array(new int[,]
      {
          { 1, 0, 1, 0 },
          { 0, 1, 0, 1 },
          { 0, 0, 0, 1 }
      });
      scores = np.array(new float[,]
      {
          { 0.82f, 0.5f, 0.9f, 0.0f },
          { 0f, 1f, 0.4f, 0.98f },
          { 0.89f, 0.79f, 0f, 0.3f }
      });
      metric = tf.keras.metrics.HammingLoss(mode: "multilabel", threshold: 0.8f);
      metric.update_state(labels, scores);
      Assert.AreEqual(metric.result().numpy(), 0.16666667f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.TopKCategoricalAccuracy</c> with <c>k: 1</c> and
  /// one-hot labels for two samples: the unweighted result is expected to be 0.5, and
  /// the result after a reset with sample weights { 0.7, 0.3 } is expected to be 0.3.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void TopKCategoricalAccuracy()
  {
      var oneHotLabels = np.array(new int[,] { { 0, 0, 1 }, { 0, 1, 0 } });
      var scores = np.array(new float[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
      var metric = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);

      // Unweighted pass.
      metric.update_state(oneHotLabels, scores);
      Assert.AreEqual(metric.result().numpy(), 0.5f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.7f, 0.3f });
      metric.update_state(oneHotLabels, scores, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.3f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.SparseTopKCategoricalAccuracy</c> with <c>k: 1</c>
  /// and integer class labels { 2, 1 }: the unweighted result is expected to be 0.5,
  /// and the result after a reset with sample weights { 0.7, 0.3 } is expected to be 0.3.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void SparseTopKCategoricalAccuracy()
  {
      var classIds = np.array(new int[] { 2, 1 });
      var scores = np.array(new float[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
      var metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1);

      // Unweighted pass.
      metric.update_state(classIds, scores);
      Assert.AreEqual(metric.result().numpy(), 0.5f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0.7f, 0.3f });
      metric.update_state(classIds, scores, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 0.3f);
  }
  /// <summary>
  /// Exercises the functional <c>tf.keras.metrics.top_k_categorical_accuracy</c> with
  /// <c>k: 3</c> on two samples over three classes; the per-sample result is expected
  /// to equal { 1, 1 }.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
  /// </summary>
  [TestMethod]
  public void top_k_categorical_accuracy()
  {
      var oneHotLabels = np.array(new int[,] { { 0, 0, 1 }, { 0, 1, 0 } });
      var scores = np.array(new float[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });

      var perSample = tf.keras.metrics.top_k_categorical_accuracy(oneHotLabels, scores, k: 3);

      Assert.AreEqual(perSample.numpy(), new float[] { 1f, 1f });
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.Precision</c> in four steps: unweighted (expected
  /// 0.6666667), weighted with { 0, 0, 1, 0 } after a reset (expected 1), and two
  /// freshly constructed metrics using <c>top_k: 2</c> (expected 0) and
  /// <c>top_k: 4</c> (expected 0.5).
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
  /// </summary>
  [TestMethod]
  public void Precision()
  {
      var labels = np.array(new int[] { 0, 1, 1, 1 });
      var predictions = np.array(new int[] { 1, 0, 1, 1 });
      var metric = tf.keras.metrics.Precision();

      // Unweighted pass.
      metric.update_state(labels, predictions);
      Assert.AreEqual(metric.result().numpy(), 0.6666667f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0f, 0f, 1f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 1f);

      // With top_k=2, it will calculate precision over y_true[:2]
      // and y_pred[:2]
      metric = tf.keras.metrics.Precision(top_k: 2);
      metric.update_state(
          np.array(new int[] { 0, 0, 1, 1 }),
          np.array(new int[] { 1, 1, 1, 1 }));
      Assert.AreEqual(metric.result().numpy(), 0f);

      // With top_k=4, it will calculate precision over y_true[:4]
      // and y_pred[:4]
      metric = tf.keras.metrics.Precision(top_k: 4);
      metric.update_state(
          np.array(new int[] { 0, 0, 1, 1 }),
          np.array(new int[] { 1, 1, 1, 1 }));
      Assert.AreEqual(metric.result().numpy(), 0.5f);
  }
  /// <summary>
  /// Exercises <c>tf.keras.metrics.Recall</c> on four samples: the unweighted result
  /// is expected to be 0.6666667, and the result after a reset with sample weights
  /// { 0, 0, 1, 0 } is expected to be 1.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
  /// </summary>
  [TestMethod]
  public void Recall()
  {
      var labels = np.array(new int[] { 0, 1, 1, 1 });
      var predictions = np.array(new int[] { 1, 0, 1, 1 });
      var metric = tf.keras.metrics.Recall();

      // Unweighted pass.
      metric.update_state(labels, predictions);
      Assert.AreEqual(metric.result().numpy(), 0.6666667f);

      // Weighted pass on the reset metric.
      metric.reset_states();
      var sampleWeights = np.array(new float[] { 0f, 0f, 1f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      Assert.AreEqual(metric.result().numpy(), 1f);
  }
  293. }