
Model.Train.cs
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Gradients;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
        IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator)
        {
            var data = iterator.next();
            var outputs = train_step(data[0], data[1]);
            // Bump the training step counter after the step's ops have run.
            tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
            return outputs;
        }

        /// <summary>
        /// The logic for one training step.
        /// </summary>
        /// <param name="x">Input batch.</param>
        /// <param name="y">Target batch.</param>
        /// <returns>The current value of each metric, keyed by metric name.</returns>
        List<(string, Tensor)> train_step(Tensor x, Tensor y)
        {
            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
            using var tape = tf.GradientTape();
            var y_pred = Apply(x, training: true);
            var loss = compiled_loss.Call(y, y_pred);

            // For custom training steps, users can just write:
            //   trainable_variables = self.trainable_variables
            //   gradients = tape.gradient(loss, trainable_variables)
            //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
            // The _minimize call does a few extra steps unnecessary in most cases,
            // such as loss scaling and gradient clipping.
            _minimize(tape, optimizer, loss, trainable_variables);
            compiled_metrics.update_state(y, y_pred);
            // Report each metric's current value by name.
            return metrics.Select(m => (m.Name, m.result())).ToList();
        }

        void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
        {
            var gradients = tape.gradient(loss, trainable_variables);
            gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables));
            gradients = optimizer._clip_gradients(gradients);
            optimizer.apply_gradients(zip(gradients, trainable_variables.Select(v => v as ResourceVariable)),
                experimental_aggregate_gradients: false);
        }
    }
}
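
The comment inside train_step points at the simpler path a custom training step can take. Below is a minimal sketch of that path in C#, written as another partial piece of Model so it can reuse the same members shown above (Apply, compiled_loss, optimizer, trainable_variables, compiled_metrics, metrics); custom_train_step is an illustrative name, not part of the real API. It skips the gradient aggregation and clipping that _minimize performs and leaves apply_gradients with its default aggregation behavior.

using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
    public partial class Model
    {
        // Sketch only: the plain training step described in the comment
        // inside train_step, without _minimize's extra steps.
        List<(string, Tensor)> custom_train_step(Tensor x, Tensor y)
        {
            using var tape = tf.GradientTape();
            var y_pred = Apply(x, training: true);
            var loss = compiled_loss.Call(y, y_pred);
            // Compute gradients with the tape and hand them straight to the
            // optimizer; no _aggregate_gradients or _clip_gradients, per the
            // "unnecessary in most cases" caveat above.
            var gradients = tape.gradient(loss, trainable_variables);
            optimizer.apply_gradients(zip(gradients,
                trainable_variables.Select(v => v as ResourceVariable)));
            compiled_metrics.update_state(y, y_pred);
            return metrics.Select(m => (m.Name, m.result())).ToList();
        }
    }
}

Compared with the _minimize path, this is exactly the three-line recipe from the comment: compute gradients with the tape, then apply them with optimizer.apply_gradients.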