You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsTest.cs 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using System.Threading.Tasks;
  7. using Tensorflow;
  8. using Tensorflow.NumPy;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. namespace TensorFlowNET.Keras.UnitTest;
  /// <summary>
  /// Eager-mode tests for a few <c>tf.keras.metrics</c> entry points, each pinned to the
  /// numbers of a small hand-checkable example.
  /// </summary>
  [TestClass]
  public class MetricsTest : EagerModeTestBase
  {
      /// <summary>
      /// Exercises the stateful <c>TopKCategoricalAccuracy</c> metric with <c>k = 1</c>:
      /// once without weights, then again after <c>reset_states()</c> with per-sample weights.
      /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
      /// </summary>
      [TestMethod]
      public void TopKCategoricalAccuracy()
      {
          // Arrange: two one-hot labels over three classes (sample 0 -> class 2,
          // sample 1 -> class 1). Both score rows put their largest value on class 1.
          var labels = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
          var scores = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
          var sampleWeights = np.array(new[] { 0.7f, 0.3f });
          var metric = tf.keras.metrics.TopKCategoricalAccuracy(k: 1);

          // Unweighted pass: only sample 1's top-scoring class matches its label -> 1 of 2.
          metric.update_state(labels, scores);
          var unweighted = metric.result().numpy();
          // NOTE(review): the NDArray is deliberately the FIRST argument here and below.
          // Presumably NDArray's Equals override handles the scalar comparison, whereas
          // float.Equals(NDArray) would not - confirm before switching to MSTest's
          // (expected, actual) order.
          Assert.AreEqual(unweighted, 0.5f);

          // Weighted pass on fresh state: the single hit (sample 1) carries weight 0.3
          // out of 0.7 + 0.3, giving 0.3.
          metric.reset_states();
          metric.update_state(labels, scores, sample_weight: sampleWeights);
          var weighted = metric.result().numpy();
          Assert.AreEqual(weighted, 0.3f);
      }

      /// <summary>
      /// Exercises the functional <c>top_k_categorical_accuracy</c>, which yields one value
      /// per sample rather than an accumulated result.
      /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
      /// </summary>
      [TestMethod]
      public void top_k_categorical_accuracy()
      {
          // Arrange: same two-sample, three-class data as the stateful test above.
          var labels = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
          var scores = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });

          // Act: k equals the number of classes, so every label falls within the top k.
          var perSample = tf.keras.metrics.top_k_categorical_accuracy(labels, scores, k: 3);

          // Assert: both samples count as hits.
          Assert.AreEqual(perSample.numpy(), new[] { 1f, 1f });
      }

      /// <summary>
      /// Exercises the stateful <c>Recall</c> metric on binary labels/predictions, first
      /// unweighted and then, after <c>reset_states()</c>, with a weight mask.
      /// Reference: https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
      /// </summary>
      [TestMethod]
      public void Recall()
      {
          // Arrange: actual positives sit at indices 1, 2 and 3; predicted positives at
          // indices 0, 2 and 3. The mask keeps only index 2.
          var actual = np.array(new[] { 0, 1, 1, 1 });
          var predicted = np.array(new[] { 1, 0, 1, 1 });
          var mask = np.array(new[] { 0f, 0f, 1f, 0f });
          var metric = tf.keras.metrics.Recall();

          // Unweighted pass: 2 of the 3 actual positives are predicted -> 2/3.
          metric.update_state(actual, predicted);
          var unweighted = metric.result().numpy();
          Assert.AreEqual(unweighted, 0.6666667f);

          // Weighted pass on fresh state: only index 2 contributes, and it is predicted
          // positive -> 1.
          metric.reset_states();
          metric.update_state(actual, predicted, sample_weight: mask);
          var weighted = metric.result().numpy();
          Assert.AreEqual(weighted, 1f);
      }
  }