using System.Collections.Generic;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Wraps the compiled loss function and records every computed loss value
/// in a <see cref="Mean"/> metric named "loss".
/// </summary>
public class LossesContainer : Container
{
    // The loss exactly as passed to the constructor; currently never read back.
    private readonly ILossFunc _user_losses;

    // The loss evaluated by Call; null means no loss was supplied.
    private readonly ILossFunc _losses;

    // Metric updated with the total loss on every Call that has a loss.
    private readonly Mean _loss_metric;

    // Set by Build; gates the lazy build in Call and what `metrics` exposes.
    private bool _built;

    // Reserved for per-output metrics; not populated yet (see _create_metrics).
    private Tensor[] _per_output_metrics;

    /// <summary>
    /// Creates a container around <paramref name="losses"/>.
    /// </summary>
    /// <param name="losses">Loss function evaluated by <see cref="Call"/>.</param>
    /// <param name="output_names">Forwarded unchanged to the <c>Container</c> base constructor.</param>
    public LossesContainer(ILossFunc losses, string[] output_names = null)
        : base(output_names)
    {
        _user_losses = losses;
        _losses = losses;
        _loss_metric = new Mean(name: "loss");
    }

    /// <summary>
    /// Computes the overall loss.
    /// </summary>
    /// <param name="y_true">Ground-truth tensor passed to the loss function.</param>
    /// <param name="y_pred">Prediction tensor passed to the loss function.</param>
    /// <param name="sample_weight">Optional weights forwarded to the loss function.</param>
    /// <returns>
    /// The summed loss, or <c>array_ops.zeros(Shape.Null)</c> when the container
    /// holds no loss function.
    /// </returns>
    public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
    {
        // Build lazily on the first call.
        if (!_built)
        {
            Build(y_pred);
        }

        var loss_values = new List<Tensor>();
        var loss_metric_values = new List<Tensor>();

        // Skip a missing loss instead of dereferencing it, so the
        // "no compiled loss" path below is actually reachable.
        if (_losses != null)
        {
            var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);

            // TODO: port losses_utils.scale_loss_for_distribution for the
            // SUM_OVER_BATCH_SIZE / AUTO reductions; it is not applied here.
            loss_values.Add(loss_value);

            // The same tensor feeds both the returned loss and the metric.
            loss_metric_values.Add(loss_value);
        }

        if (loss_values.Count == 0)
        {
            // Ok for a model to have no compiled loss.
            return array_ops.zeros(Shape.Null);
        }

        // First entry of y_true's shape; passed as the second argument of
        // update_state (presumably the sample weight - confirm against Mean).
        var batch_dim = array_ops.shape(y_true)[0];
        var total_loss_metric_value = math_ops.add_n(loss_metric_values.ToArray());
        _loss_metric.update_state(total_loss_metric_value, batch_dim);

        // TODO: port losses_utils.cast_losses_to_common_dtype before summing.
        return math_ops.add_n(loss_values.ToArray());
    }

    /// <summary>
    /// Runs metric creation and marks the container as built.
    /// </summary>
    /// <param name="y_pred">Currently unused.</param>
    public void Build(Tensor y_pred)
    {
        _create_metrics();
        _built = true;
    }

    // Currently a no-op: per-output metrics are not created yet.
    void _create_metrics()
    {
        // _per_output_metrics = _output_names.Select(x => null);
    }

    /// <summary>
    /// The metrics tracked by this container: an empty list until the container
    /// has been built, afterwards the single "loss" metric.
    /// </summary>
    public IEnumerable<Mean> metrics
    {
        get
        {
            if (!_built)
            {
                return new List<Mean>();
            }

            return new[] { _loss_metric };
        }
    }
}
}