using System.Collections.Generic;
using System.Linq;
using Tensorflow.Gradients;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
        // Pulls the next (inputs, targets) batch from the iterator, runs one
        // training step, and increments the train counter.
        IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator)
        {
            var data = iterator.next();
            var outputs = train_step(data[0], data[1]);
            tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
            return outputs;
        }

        /// <summary>
        /// The logic for one training step.
        /// </summary>
        /// <param name="x">Input batch.</param>
        /// <param name="y">Target batch.</param>
        /// <returns>Metric names paired with their current results.</returns>
        List<(string, Tensor)> train_step(Tensor x, Tensor y)
        {
            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
            using var tape = tf.GradientTape();
            var y_pred = Apply(x, training: true);
            var loss = compiled_loss.Call(y, y_pred);

            // For custom training steps, users can just write:
            //   trainable_variables = self.trainable_variables
            //   gradients = tape.gradient(loss, trainable_variables)
            //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
            // The _minimize call does a few extra steps unnecessary in most cases,
            // such as loss scaling and gradient clipping.
            _minimize(tape, optimizer, loss, trainable_variables);
            compiled_metrics.update_state(y, y_pred);
            return metrics.Select(m => (m.Name, m.result())).ToList();
        }

        // Computes gradients from the tape, aggregates and clips them, then
        // applies them to the trainable variables.
        void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
        {
            var gradients = tape.gradient(loss, trainable_variables);
            gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables));
            gradients = optimizer._clip_gradients(gradients);
            optimizer.apply_gradients(zip(gradients, trainable_variables.Select(v => v as ResourceVariable)),
                experimental_aggregate_gradients: false);
        }
    }
}
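
// ---------------------------------------------------------------------------
// A minimal sketch, not part of the original source, of the custom training
// step that the comment inside train_step describes: take the gradients
// straight from the tape and hand them to apply_gradients, skipping the
// aggregation and clipping that _minimize performs. custom_train_step is a
// hypothetical name; the body reuses only members already visible above
// (tf.GradientTape, Apply, compiled_loss, optimizer, trainable_variables,
// compiled_metrics, metrics, zip).
// ---------------------------------------------------------------------------
namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
        List<(string, Tensor)> custom_train_step(Tensor x, Tensor y)
        {
            using var tape = tf.GradientTape();
            var y_pred = Apply(x, training: true);     // forward pass
            var loss = compiled_loss.Call(y, y_pred);  // compute loss under the tape
            var gradients = tape.gradient(loss, trainable_variables);
            // Apply the raw gradients directly, mirroring the Python-style
            // comment in train_step; no explicit aggregation or clipping.
            optimizer.apply_gradients(zip(gradients,
                trainable_variables.Select(v => v as ResourceVariable)));
            compiled_metrics.update_state(y, y_pred);
            return metrics.Select(m => (m.Name, m.result())).ToList();
        }
    }
}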