You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Train.cs 4.9 kB

4 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using Tensorflow.Gradients;
  4. using Tensorflow.Keras.Engine.DataAdapters;
  5. using Tensorflow.Keras.Optimizers;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Keras.Engine
  8. {
  9. public partial class Model
  10. {
  11. Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
  12. {
  13. var data = iterator.next();
  14. // whether have sample_weight
  15. var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
  16. train_step(data_handler, data[0], data[1], data[2]);
  17. tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
  18. return outputs;
  19. }
  20. Dictionary<string, float> train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
  21. {
  22. var data = iterator.next();
  23. var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
  24. var outputs = data.Length == 2 ?
  25. train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
  26. train_step(
  27. data_handler,
  28. new Tensors(data.Take(x_size).ToArray()),
  29. new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
  30. new Tensors(data.Skip(2 * x_size).ToArray()));
  31. tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
  32. return outputs;
  33. }
  34. /// <summary>
  35. /// The logic for one training step.
  36. /// </summary>
  37. /// <param name="data_handler"></param>
  38. /// <param name="x"></param>
  39. /// <param name="y"></param>
  40. /// <returns></returns>
  41. Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y)
  42. {
  43. (x, y) = data_handler.DataAdapter.Expand1d(x, y);
  44. using var tape = tf.GradientTape();
  45. var y_pred = Apply(x, training: true);
  46. var loss = compiled_loss.Call(y, y_pred);
  47. // For custom training steps, users can just write:
  48. // trainable_variables = self.trainable_variables
  49. // gradients = tape.gradient(loss, trainable_variables)
  50. // self.optimizer.apply_gradients(zip(gradients, trainable_variables))
  51. // The _minimize call does a few extra steps unnecessary in most cases,
  52. // such as loss scaling and gradient clipping.
  53. _minimize(tape, optimizer, loss, TrainableVariables);
  54. compiled_metrics.update_state(y, y_pred);
  55. var dict = new Dictionary<string, float>();
  56. metrics.ToList().ForEach(x =>
  57. {
  58. var r = x.result();
  59. if (r.ndim > 0)
  60. {
  61. r = tf.reduce_mean(r);
  62. }
  63. dict[x.Name] = (float)r;
  64. });
  65. return dict;
  66. }
  67. Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
  68. {
  69. (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
  70. using var tape = tf.GradientTape();
  71. var y_pred = Apply(x, training: true);
  72. var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
  73. // For custom training steps, users can just write:
  74. // trainable_variables = self.trainable_variables
  75. // gradients = tape.gradient(loss, trainable_variables)
  76. // self.optimizer.apply_gradients(zip(gradients, trainable_variables))
  77. // The _minimize call does a few extra steps unnecessary in most cases,
  78. // such as loss scaling and gradient clipping.
  79. _minimize(tape, optimizer, loss, TrainableVariables);
  80. compiled_metrics.update_state(y, y_pred);
  81. var dict = new Dictionary<string, float>();
  82. metrics.ToList().ForEach(x =>
  83. {
  84. var r = x.result();
  85. if (r.ndim > 0)
  86. {
  87. r = tf.reduce_mean(r);
  88. }
  89. dict[x.Name] = (float)r;
  90. });
  91. return dict;
  92. }
  93. void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables)
  94. {
  95. var gradients = tape.gradient(loss, trainable_variables);
  96. gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
  97. gradients = optimizer.clip_gradients(gradients);
  98. optimizer.apply_gradients(zip(gradients, trainable_variables),
  99. experimental_aggregate_gradients: false);
  100. }
  101. }
  102. }