You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

MetricsTest.cs 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow;
  8. using Tensorflow.NumPy;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. namespace TensorFlowNET.Keras.UnitTest;
  /// <summary>
  /// Tests for the Keras metrics exposed through <c>tf.keras.metrics</c>. Each test
  /// reproduces the standalone-usage example from the linked TensorFlow Python API page.
  /// </summary>
  [TestClass]
  public class MetricsTest : EagerModeTestBase
  {
      // Metric results are float32 values computed by TensorFlow; compare them with a
      // small absolute tolerance rather than exact equality.
      private const float Tolerance = 1e-6f;

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
      /// </summary>
      [TestMethod]
      public void BinaryAccuracy()
      {
          var y_true = np.array(new[,] { { 1 }, { 1 }, { 0 }, { 0 } });
          var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
          var m = tf.keras.metrics.BinaryAccuracy();

          // Unweighted: assuming the usual 0.5 threshold, 3 of the 4 predictions match.
          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(0.75f, r);

          // Weighted: only samples 0 (correct) and 3 (incorrect) carry weight.
          m.reset_states();
          var weights = np.array(new[] { 1f, 0f, 0f, 1f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(0.5f, r);
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalAccuracy
      /// </summary>
      [TestMethod]
      public void CategoricalAccuracy()
      {
          var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
          var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
          var m = tf.keras.metrics.CategoricalAccuracy();

          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(0.5f, r);

          // Only the second row's highest-scoring class matches its label, so the
          // weighted accuracy equals that row's weight.
          m.reset_states();
          var weights = np.array(new[] { 0.7f, 0.3f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(0.3f, r);
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CategoricalCrossentropy
      /// </summary>
      [TestMethod]
      public void CategoricalCrossentropy()
      {
          var y_true = np.array(new[,] { { 0, 1, 0 }, { 0, 0, 1 } });
          var y_pred = np.array(new[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });
          var m = tf.keras.metrics.CategoricalCrossentropy();

          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(1.1769392f, r);

          m.reset_states();
          var weights = np.array(new[] { 0.3f, 0.7f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(1.6271976f, r);
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
      /// </summary>
      [TestMethod]
      public void TopKCategoricalAccuracy()
      {
          var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
          var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
          // k = 1 makes this equivalent to the CategoricalAccuracy case above.
          var m = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);

          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(0.5f, r);

          m.reset_states();
          var weights = np.array(new[] { 0.7f, 0.3f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(0.3f, r);
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
      /// </summary>
      [TestMethod]
      public void top_k_categorical_accuracy()
      {
          var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
          var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });

          // With k = 3 and only 3 classes, every label is within the top k.
          var result = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);

          // The NDArray must stay the first argument: this resolves to
          // Assert.AreEqual(object, object), which calls expected.Equals(actual), and
          // float[].Equals is reference equality. NOTE(review): this relies on
          // NDArray.Equals comparing element-wise against a float[] — confirm.
          Assert.AreEqual(result.numpy(), new[] { 1f, 1f });
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
      /// </summary>
      [TestMethod]
      public void Precision()
      {
          var y_true = np.array(new[] { 0, 1, 1, 1 });
          var y_pred = np.array(new[] { 1, 0, 1, 1 });
          var m = tf.keras.metrics.Precision();

          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(0.6666667f, r);

          m.reset_states();
          var weights = np.array(new[] { 0f, 0f, 1f, 0f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(1f, r);

          // With top_k=2, precision is calculated over y_true[:2] and y_pred[:2].
          m = tf.keras.metrics.Precision(top_k: 2);
          m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
          r = m.result().numpy();
          AssertScalar(0f, r);

          // With top_k=4, precision is calculated over y_true[:4] and y_pred[:4].
          m = tf.keras.metrics.Precision(top_k: 4);
          m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
          r = m.result().numpy();
          AssertScalar(0.5f, r);
      }

      /// <summary>
      /// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
      /// </summary>
      [TestMethod]
      public void Recall()
      {
          var y_true = np.array(new[] { 0, 1, 1, 1 });
          var y_pred = np.array(new[] { 1, 0, 1, 1 });
          var m = tf.keras.metrics.Recall();

          m.update_state(y_true, y_pred);
          var r = m.result().numpy();
          AssertScalar(0.6666667f, r);

          m.reset_states();
          var weights = np.array(new[] { 0f, 0f, 1f, 0f });
          m.update_state(y_true, y_pred, sample_weight: weights);
          r = m.result().numpy();
          AssertScalar(1f, r);
      }

      /// <summary>
      /// Asserts that a scalar metric result equals <paramref name="expected"/> within
      /// <paramref name="delta"/>.
      /// </summary>
      /// <remarks>
      /// Arguments are passed to MSTest in (expected, actual) order so failure messages
      /// label the values correctly. The NDArray is converted to float first because
      /// passing it and a float directly would fall back to object equality.
      /// </remarks>
      /// <param name="expected">The reference value.</param>
      /// <param name="actual">A scalar result, e.g. <c>metric.result().numpy()</c>.</param>
      /// <param name="delta">Maximum allowed absolute difference.</param>
      private static void AssertScalar(float expected, NDArray actual, float delta = Tolerance)
      {
          Assert.AreEqual(expected, (float)actual, delta);
      }
  }