using System.Collections.Generic;
using System.Linq;
using Tensorflow.Gradients;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
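        /// <summary>
        /// Runs one step of the outer training loop: pulls the next (x, y) batch
        /// from the iterator, executes train_step, and increments the step counter.
        /// </summary>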
        IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator)
        {
            var data = iterator.next();
            var outputs = train_step(data[0], data[1]);
            // Increment _train_counter once per executed training step.
            tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
            return outputs;
        }
        /// <summary>
        /// The logic for one training step.
        /// </summary>
        /// <param name="x">Input batch.</param>
        /// <param name="y">Target batch.</param>
        /// <returns>The (metric name, result) pairs after the step.</returns>
        List<(string, Tensor)> train_step(Tensor x, Tensor y)
        {
            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
            using var tape = tf.GradientTape();
            // Forward pass in training mode, recorded on the tape.
            var y_pred = Apply(x, training: true);
            var loss = compiled_loss.Call(y, y_pred);
            // For custom training steps, users can just write:
            //   trainable_variables = self.trainable_variables
            //   gradients = tape.gradient(loss, trainable_variables)
            //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
            // The _minimize call does a few extra steps unnecessary in most cases,
            // such as loss scaling and gradient clipping.
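            // A rough C# equivalent of that custom step, using only the members
            // this file already relies on (a sketch, not shipped API):
            //   var grads = tape.gradient(loss, trainable_variables);
            //   optimizer.apply_gradients(zip(grads, trainable_variables.Select(v => v as ResourceVariable)));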
            _minimize(tape, optimizer, loss, trainable_variables);
            // Update compiled metrics and report each metric's current result.
            compiled_metrics.update_state(y, y_pred);
            return metrics.Select(m => (m.Name, m.result())).ToList();
        }
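        /// <summary>
        /// Computes gradients from the tape, then aggregates, clips, and applies
        /// them to the trainable variables.
        /// </summary>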
        void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
        {
            var gradients = tape.gradient(loss, trainable_variables);
            // Aggregate (e.g. across replicas), then clip, before applying.
            gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables));
            gradients = optimizer._clip_gradients(gradients);
            optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
                experimental_aggregate_gradients: false);
        }
    }
}