You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LossesContainer.cs 2.6 kB

4 years ago
4 years ago
4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. using System.Collections.Generic;
  2. using Tensorflow.Keras.Losses;
  3. using Tensorflow.Keras.Metrics;
  namespace Tensorflow.Keras.Engine
  {
      /// <summary>
      /// Wraps the loss function passed at compile time: computes the total loss for a batch
      /// and tracks its running mean in a <see cref="Mean"/> metric named "loss".
      /// </summary>
      public class LossesContainer : Container
      {
          // The loss object exactly as supplied by the caller. It is not read anywhere in this
          // class; presumably kept to mirror Keras' `_user_losses` — TODO confirm it is still needed.
          readonly ILossFunc _user_losses;

          // The loss object applied by Call. May be null when the model was compiled without a loss.
          readonly ILossFunc _losses;

          // Running mean of the per-batch total loss; exposed through `metrics` once built.
          readonly Mean _loss_metric;

          // Set by Build; Call builds lazily on first use.
          bool _built;

          /// <summary>
          /// Creates a container around a single loss function.
          /// </summary>
          /// <param name="losses">Loss function to apply; null means the model has no compiled loss.</param>
          /// <param name="output_names">Output names, forwarded to the <see cref="Container"/> base.</param>
          public LossesContainer(ILossFunc losses, string[] output_names = null)
              : base(output_names)
          {
              _user_losses = losses;
              _losses = losses;
              _loss_metric = new Mean(name: "loss");
              _built = false;
          }

          /// <summary>
          /// Computes the overall loss and records it in the "loss" metric.
          /// </summary>
          /// <param name="y_true">Ground-truth values.</param>
          /// <param name="y_pred">Predicted values.</param>
          /// <param name="sample_weight">Optional weights, forwarded unchanged to the loss function.</param>
          /// <returns>
          /// The total loss, or <c>array_ops.zeros(Shape.Null)</c> when no loss function was compiled
          /// (in that case the "loss" metric is not updated).
          /// </returns>
          public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
          {
              if (!_built)
              {
                  Build(y_pred);
              }

              if (_losses is null)
              {
                  // Ok for a model to have no compiled loss.
                  // NOTE(review): Keras returns a scalar zeros(shape=()); confirm Shape.Null yields that here.
                  return array_ops.zeros(Shape.Null);
              }

              var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);

              // Passed to update_state below, presumably as the sample weight so that the running
              // mean is weighted by batch size — TODO confirm against Mean.update_state.
              var batch_dim = array_ops.shape(y_true)[0];

              // Only a single loss function is supported, so the value tracked by the metric and the
              // returned loss are the same tensor. add_n is kept so several losses can be summed later.
              // TODO: scale SUM_OVER_BATCH_SIZE / AUTO losses for distribution
              // (losses_utils.scale_loss_for_distribution) and cast losses to a common dtype once ported.
              var loss_values = new[] { loss_value };

              var total_loss_metric_value = math_ops.add_n(loss_values);
              _loss_metric.update_state(total_loss_metric_value, batch_dim);

              return math_ops.add_n(loss_values);
          }

          /// <summary>
          /// Creates the container's metrics and marks it as built.
          /// </summary>
          /// <param name="y_pred">Predicted values; not used by the current implementation.</param>
          public void Build(Tensor y_pred)
          {
              _create_metrics();
              _built = true;
          }

          // Placeholder: per-output metrics (Keras' `_per_output_metrics`) are not ported yet.
          void _create_metrics()
          {
          }

          /// <summary>
          /// Metrics tracked by this container: empty before the container is built,
          /// otherwise the single "loss" mean.
          /// </summary>
          public IEnumerable<Metric> metrics
          {
              get
              {
                  if (!_built)
                  {
                      return new List<Metric>();
                  }

                  return new[] { _loss_metric };
              }
          }
      }
  }