You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsTest.cs 12 kB


  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow;
  8. using Tensorflow.NumPy;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. namespace TensorFlowNET.Keras.UnitTest;
  12. [TestClass]
  13. public class MetricsTest : EagerModeTestBase
  14. {
  /// <summary>
  /// Checks <c>tf.keras.metrics.Accuracy</c>, first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy
  /// </summary>
  [TestMethod]
  public void Accuracy()
  {
      // Three of the four predictions equal their label.
      int[,] labelData = { { 1 }, { 2 }, { 3 }, { 4 } };
      float[,] predictionData = { { 0f }, { 2f }, { 3f }, { 4f } };
      var labels = np.array(labelData);
      var predictions = np.array(predictionData);
      var metric = tf.keras.metrics.Accuracy();

      metric.update_state(labels, predictions);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.75f);

      // Reset the metric, then repeat with weights that zero out the last two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 1f, 1f, 0f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.5f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.BinaryAccuracy</c>, first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
  /// </summary>
  [TestMethod]
  public void BinaryAccuracy()
  {
      int[,] labelData = { { 1 }, { 1 }, { 0 }, { 0 } };
      float[,] scoreData = { { 0.98f }, { 1f }, { 0f }, { 0.6f } };
      var labels = np.array(labelData);
      var scores = np.array(scoreData);
      var metric = tf.keras.metrics.BinaryAccuracy();

      // Unweighted pass over all four samples.
      metric.update_state(labels, scores);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.75f);

      // Reset the metric, then repeat with only the first and last samples weighted.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 1f, 0f, 0f, 1f });
      metric.update_state(labels, scores, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.5f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.CategoricalAccuracy</c> on one-hot labels,
  /// first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void CategoricalAccuracy()
  {
      // Two samples, three classes each; labels are one-hot rows.
      int[,] oneHotLabels = { { 0, 0, 1 }, { 0, 1, 0 } };
      float[,] classScores = { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } };
      var labels = np.array(oneHotLabels);
      var scores = np.array(classScores);
      var metric = tf.keras.metrics.CategoricalAccuracy();

      metric.update_state(labels, scores);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.5f);

      // Reset the metric, then repeat with weights 0.7 / 0.3 on the two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.7f, 0.3f });
      metric.update_state(labels, scores, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.3f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.SparseCategoricalAccuracy</c> on integer class labels,
  /// first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void SparseCategoricalAccuracy()
  {
      // Labels are class indices (one per sample) rather than one-hot rows.
      int[] classIndices = { 2, 1 };
      float[,] classScores = { { 0.1f, 0.6f, 0.3f }, { 0.05f, 0.95f, 0f } };
      var labels = np.array(classIndices);
      var scores = np.array(classScores);
      var metric = tf.keras.metrics.SparseCategoricalAccuracy();

      metric.update_state(labels, scores);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.5f);

      // Reset the metric, then repeat with weights 0.7 / 0.3 on the two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.7f, 0.3f });
      metric.update_state(labels, scores, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.3f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.CategoricalCrossentropy</c> on one-hot labels,
  /// first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
  /// </summary>
  [TestMethod]
  public void CategoricalCrossentropy()
  {
      int[,] oneHotLabels = { { 0, 1, 0 }, { 0, 0, 1 } };
      float[,] probabilities = { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } };
      var labels = np.array(oneHotLabels);
      var predictions = np.array(probabilities);
      var metric = tf.keras.metrics.CategoricalCrossentropy();

      // Unweighted pass.
      metric.update_state(labels, predictions);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 1.1769392f);

      // Reset the metric, then repeat with weights 0.3 / 0.7 on the two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.3f, 0.7f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 1.6271976f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.SparseCategoricalCrossentropy</c> on integer class labels.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalCrossentropy
  /// </summary>
  [TestMethod]
  public void SparseCategoricalCrossentropy()
  {
      // Labels are class indices (one per sample); predictions are per-class scores.
      int[] classIndices = { 1, 2 };
      float[,] probabilities = { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } };
      var labels = np.array(classIndices);
      var predictions = np.array(probabilities);

      var metric = tf.keras.metrics.SparseCategoricalCrossentropy();
      metric.update_state(labels, predictions);

      var value = metric.result().numpy();
      Assert.AreEqual(value, 1.1769392f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.CosineSimilarity</c> with <c>axis: 1</c>,
  /// first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity
  /// </summary>
  [TestMethod]
  public void CosineSimilarity()
  {
      // Row 0: (0,1) vs (1,0); row 1: (1,1) vs (1,1).
      int[,] targetRows = { { 0, 1 }, { 1, 1 } };
      float[,] predictedRows = { { 1f, 0f }, { 1f, 1f } };
      var targets = np.array(targetRows);
      var predictions = np.array(predictedRows);
      var metric = tf.keras.metrics.CosineSimilarity(axis: 1);

      metric.update_state(targets, predictions);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.49999997f);

      // Reset the metric, then repeat with weights 0.3 / 0.7 on the two rows.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.3f, 0.7f });
      metric.update_state(targets, predictions, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.6999999f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.F1Score</c> for three classes with a 0.5 threshold;
  /// the result is compared against one expected value per class.
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score
  /// </summary>
  [TestMethod]
  public void F1Score()
  {
      // Three samples, three label columns each.
      int[,] labelRows = { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } };
      float[,] scoreRows = { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } };
      var labels = np.array(labelRows);
      var scores = np.array(scoreRows);

      var metric = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f);
      metric.update_state(labels, scores);

      var perClass = metric.result().numpy();
      Assert.AreEqual(perClass, new[] { 0.5f, 0.8f, 0.6666667f });
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.FBetaScore</c> for three classes with beta 2.0 and a 0.5 threshold;
  /// the result is compared against one expected value per class.
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore
  /// </summary>
  [TestMethod]
  public void FBetaScore()
  {
      // Same inputs as the F1Score test in this class; only the metric differs.
      int[,] labelRows = { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } };
      float[,] scoreRows = { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } };
      var labels = np.array(labelRows);
      var scores = np.array(scoreRows);

      var metric = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f);
      metric.update_state(labels, scores);

      var perClass = metric.result().numpy();
      Assert.AreEqual(perClass, new[] { 0.3846154f, 0.90909094f, 0.8333334f });
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.HammingLoss</c> in both of the modes exercised here:
  /// "multiclass" (threshold 0.6) and "multilabel" (threshold 0.8).
  /// Reference: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss
  /// </summary>
  [TestMethod]
  public void HammingLoss()
  {
      // ---- Multi-class Hamming loss: four samples, one-hot labels over four classes ----
      int[,] multiClassLabels =
      {
          { 1, 0, 0, 0 },
          { 0, 0, 1, 0 },
          { 0, 0, 0, 1 },
          { 0, 1, 0, 0 }
      };
      float[,] multiClassScores =
      {
          { 0.8f, 0.1f, 0.1f, 0.0f },
          { 0.2f, 0.0f, 0.8f, 0.0f },
          { 0.05f, 0.05f, 0.1f, 0.8f },
          { 1.0f, 0.0f, 0.0f, 0.0f }
      };
      var y_true = np.array(multiClassLabels);
      var y_pred = np.array(multiClassScores);

      var metric = tf.keras.metrics.HammingLoss(mode: "multiclass", threshold: 0.6f);
      metric.update_state(y_true, y_pred);
      var loss = metric.result().numpy();
      Assert.AreEqual(loss, 0.25f);

      // ---- Multi-label Hamming loss: three samples, several labels may be set per row ----
      int[,] multiLabelLabels =
      {
          { 1, 0, 1, 0 },
          { 0, 1, 0, 1 },
          { 0, 0, 0, 1 }
      };
      float[,] multiLabelScores =
      {
          { 0.82f, 0.5f, 0.9f, 0.0f },
          { 0f, 1f, 0.4f, 0.98f },
          { 0.89f, 0.79f, 0f, 0.3f }
      };
      y_true = np.array(multiLabelLabels);
      y_pred = np.array(multiLabelScores);

      // A fresh metric instance is created for the second mode.
      metric = tf.keras.metrics.HammingLoss(mode: "multilabel", threshold: 0.8f);
      metric.update_state(y_true, y_pred);
      loss = metric.result().numpy();
      Assert.AreEqual(loss, 0.16666667f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.TopKCategoricalAccuracy</c> with <c>k: 1</c> on one-hot labels,
  /// first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void TopKCategoricalAccuracy()
  {
      int[,] oneHotLabels = { { 0, 0, 1 }, { 0, 1, 0 } };
      float[,] classScores = { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } };
      var labels = np.array(oneHotLabels);
      var scores = np.array(classScores);
      var metric = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);

      // Unweighted pass.
      metric.update_state(labels, scores);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.5f);

      // Reset the metric, then repeat with weights 0.7 / 0.3 on the two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.7f, 0.3f });
      metric.update_state(labels, scores, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.3f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.SparseTopKCategoricalAccuracy</c> with <c>k: 1</c> on integer
  /// class labels, first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseTopKCategoricalAccuracy
  /// </summary>
  [TestMethod]
  public void SparseTopKCategoricalAccuracy()
  {
      // Labels are class indices (one per sample) rather than one-hot rows.
      int[] classIndices = { 2, 1 };
      float[,] classScores = { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } };
      var labels = np.array(classIndices);
      var scores = np.array(classScores);
      var metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k: 1);

      metric.update_state(labels, scores);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.5f);

      // Reset the metric, then repeat with weights 0.7 / 0.3 on the two samples.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0.7f, 0.3f });
      metric.update_state(labels, scores, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 0.3f);
  }
  /// <summary>
  /// Checks the functional form <c>tf.keras.metrics.top_k_categorical_accuracy</c> with <c>k: 3</c>;
  /// the result is compared against one expected value per sample.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
  /// </summary>
  [TestMethod]
  public void top_k_categorical_accuracy()
  {
      // k equals the number of classes (3) in this input.
      int[,] oneHotLabels = { { 0, 0, 1 }, { 0, 1, 0 } };
      float[,] classScores = { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } };
      var labels = np.array(oneHotLabels);
      var scores = np.array(classScores);

      var perSample = tf.keras.metrics.top_k_categorical_accuracy(labels, scores, k: 3);

      Assert.AreEqual(perSample.numpy(), new[] { 1f, 1f });
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.Precision</c> in four configurations: default, with
  /// per-sample weights, with <c>top_k: 2</c>, and with <c>top_k: 4</c>.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
  /// </summary>
  [TestMethod]
  public void Precision()
  {
      int[] labelData = { 0, 1, 1, 1 };
      int[] predictionData = { 1, 0, 1, 1 };
      var labels = np.array(labelData);
      var predictions = np.array(predictionData);
      var metric = tf.keras.metrics.Precision();

      // Default configuration, no weights.
      metric.update_state(labels, predictions);
      var value = metric.result().numpy();
      Assert.AreEqual(value, 0.6666667f);

      // Reset the metric, then weight only the third sample.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0f, 0f, 1f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      value = metric.result().numpy();
      Assert.AreEqual(value, 1f);

      // With top_k=2, it will calculate precision over y_true[:2]
      // and y_pred[:2] (note carried over from the original test).
      metric = tf.keras.metrics.Precision(top_k: 2);
      var topTwoLabels = np.array(new[] { 0, 0, 1, 1 });
      var topTwoPredictions = np.array(new[] { 1, 1, 1, 1 });
      metric.update_state(topTwoLabels, topTwoPredictions);
      value = metric.result().numpy();
      Assert.AreEqual(value, 0f);

      // With top_k=4, it will calculate precision over y_true[:4]
      // and y_pred[:4] (note carried over from the original test).
      metric = tf.keras.metrics.Precision(top_k: 4);
      var topFourLabels = np.array(new[] { 0, 0, 1, 1 });
      var topFourPredictions = np.array(new[] { 1, 1, 1, 1 });
      metric.update_state(topFourLabels, topFourPredictions);
      value = metric.result().numpy();
      Assert.AreEqual(value, 0.5f);
  }
  /// <summary>
  /// Checks <c>tf.keras.metrics.Recall</c>, first without and then with per-sample weights.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
  /// </summary>
  [TestMethod]
  public void Recall()
  {
      int[] labelData = { 0, 1, 1, 1 };
      int[] predictionData = { 1, 0, 1, 1 };
      var labels = np.array(labelData);
      var predictions = np.array(predictionData);
      var metric = tf.keras.metrics.Recall();

      // Unweighted pass over all four samples.
      metric.update_state(labels, predictions);
      var unweighted = metric.result().numpy();
      Assert.AreEqual(unweighted, 0.6666667f);

      // Reset the metric, then weight only the third sample.
      metric.reset_states();
      var sampleWeights = np.array(new[] { 0f, 0f, 1f, 0f });
      metric.update_state(labels, predictions, sample_weight: sampleWeights);
      var weighted = metric.result().numpy();
      Assert.AreEqual(weighted, 1f);
  }
  300. }