You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.cs 7.2 kB

2 years ago
2 years ago
4 years ago
4 years ago

  1. using System.Diagnostics;
  2. using Tensorflow.Framework.Models;
  3. using Tensorflow.Keras.ArgsDefinition;
  4. using Tensorflow.Keras.Losses;
  5. using Tensorflow.Keras.Saving;
  6. using Tensorflow.Keras.Saving.SavedModel;
  7. using Tensorflow.Keras.Utils;
  8. using Tensorflow.Train;
  9. using Tensorflow.Util;
  namespace Tensorflow.Keras.Engine
  {
      /// <summary>
      /// `Model` groups layers into an object with training and inference features.
      /// </summary>
      public partial class Model : Layer, IModel
      {
  #pragma warning disable CS0169 // The field 'Model._cloning' is never used
          bool _cloning;
  #pragma warning restore CS0169 // The field 'Model._cloning' is never used
  #pragma warning disable CS0108 // Member hides inherited member; missing new keyword
  #pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
          bool _is_compiled;
  #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
  #pragma warning restore CS0108 // Member hides inherited member; missing new keyword
          ILossFunc loss;
          IOptimizer optimizer;
          IVariableV1 _steps_per_execution;
          protected bool _is_graph_network;
          public Tensors inputs;
          protected Tensors outputs;
          protected List<string> input_names;
          public string[] output_names;
          // Int64 step counters; (re)created by _init_batch_counters().
          IVariableV1 _train_counter;
          IVariableV1 _test_counter;
          IVariableV1 _predict_counter;
          bool _base_model_initialized;
          bool stop_training;
          // Input signature recorded by _set_save_spec(); once set it is never overwritten.
          TensorSpec _saved_model_inputs_spec;

          /// <summary>Returns the value of the <c>_is_graph_network</c> flag.</summary>
          public bool IsGraphNetwork => _is_graph_network;

          /// <summary>Gets or sets the optimizer associated with this model.</summary>
          public IOptimizer Optimizer
          {
              get => optimizer;
              set => optimizer = value;
          }

          /// <summary>
          /// Gets or sets the stop-training flag. <see cref="IModel.set_stopTraining_true"/> also sets it to true.
          /// </summary>
          public bool Stop_training
          {
              get => stop_training;
              set => stop_training = value;
          }

          /// <summary>
          /// Creates a model from <paramref name="args"/> and creates its train/test/predict counters.
          /// </summary>
          /// <param name="args">Construction arguments forwarded to <see cref="Layer"/>.</param>
          public Model(ModelArgs args)
              : base(args)
          {
              _init_batch_counters();
          }

          /// <summary>Records <paramref name="inputs"/> as the save spec by forwarding to <see cref="_set_save_spec"/>.</summary>
          public void _set_inputs(TensorSpec inputs)
          {
              _set_save_spec(inputs);
          }

          /// <summary>
          /// Records the input signature of this model in <c>_saved_model_inputs_spec</c>.
          /// Does nothing if a spec has already been recorded.
          /// </summary>
          /// <param name="inputs">Input spec(s); their nesting structure is mirrored by the recorded spec.</param>
          internal void _set_save_spec(TensorSpec inputs)
          {
              if (_saved_model_inputs_spec is not null)
              {
                  return; // Already set.
              }

              // Fall back to generated names when the model has no declared input names.
              List<string> names = input_names;
              if (names is null || names.Count == 0)
              {
                  names = compile_utils.create_pseudo_input_names(inputs);
              }

              var flat_inputs = nest.flatten(inputs);
              List<TensorSpec> specs = new();
              foreach (var (name, tensor) in zip(names, flat_inputs))
              {
                  specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name));
              }

              // Re-nest the per-input specs into the same structure as `inputs`.
              var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec;
              // Check the packed result: the `as` cast is what can yield null (`specs` never is).
              Debug.Assert(packed_specs is not null);
              _saved_model_inputs_spec = packed_specs;

              // A Sequential model with no build input shape yet takes it from the recorded spec.
              if (this is Sequential && _buildInputShape is null)
              {
                  _buildInputShape = nest.map_structure<TensorSpec, TensorShapeConfig>(x => x is null ? null : x.shape, packed_specs);
              }
          }

          /// <summary>
          /// Initializes the model from <paramref name="args"/>, recreating the batch counters first.
          /// </summary>
          internal override void Initialize(LayerArgs args)
          {
              _init_batch_counters();
              base.Initialize(args);
          }

          /// <summary>
          /// Stores <paramref name="steps_per_execution"/> in an int64 variable with
          /// <see cref="VariableAggregation.OnlyFirstReplica"/> aggregation.
          /// </summary>
          void _configure_steps_per_execution(int steps_per_execution)
          {
              _steps_per_execution = tf.Variable(steps_per_execution,
                  dtype: TF_DataType.TF_INT64,
                  aggregation: VariableAggregation.OnlyFirstReplica);
          }

          /// <summary>
          /// Snapshots the current trainable state and clears <c>keras.backend._GRAPH</c>.
          /// </summary>
          void _reset_compile_cache()
          {
              // Used to cache `trainable` attr of `Layer`s for `fit`.
              _compiled_trainable_state = _get_trainable_state();
              keras.backend._GRAPH = null;
          }

          /// <summary>
          /// Creates fresh int64 train/test/predict counters, each starting at 0.
          /// </summary>
          void _init_batch_counters()
          {
              _train_counter = NewCounter();
              _test_counter = NewCounter();
              _predict_counter = NewCounter();

              // Zero-initialized int64 counter aggregated on the first replica only.
              IVariableV1 NewCounter()
                  => tf.Variable(0L,
                      dtype: TF_DataType.TF_INT64,
                      aggregation: VariableAggregation.OnlyFirstReplica);
          }

          /// <summary>Direct (non-recursive) child layers of this model, excluding the model itself.</summary>
          public override List<ILayer> Layers
              => _flatten_layers(recursive: false, include_self: false).ToList();

          /// <summary>
          /// De-duplicated trainable weights of the tracked trainable children plus this model's own.
          /// Empty when the model itself is not trainable.
          /// </summary>
          public override List<IVariableV1> TrainableWeights
          {
              get
              {
                  // skip the assertion of weights created.
                  var variables = new List<IVariableV1>();
                  if (!Trainable)
                  {
                      return variables;
                  }
                  foreach (var trackable_obj in _self_tracked_trackables)
                  {
                      if (trackable_obj.Trainable)
                          variables.AddRange(trackable_obj.TrainableWeights);
                  }
                  variables.AddRange(_trainable_weights);
                  return variables.Distinct().ToList();
              }
          }

          /// <summary>
          /// De-duplicated non-trainable weights of the tracked children plus this model's own.
          /// When the model is not trainable, all trainable weights are reported here as well.
          /// </summary>
          public override List<IVariableV1> NonTrainableWeights
          {
              get
              {
                  // skip the assertion of weights created.
                  var variables = new List<IVariableV1>();
                  foreach (var trackable_obj in _self_tracked_trackables)
                  {
                      variables.AddRange(trackable_obj.NonTrainableWeights);
                  }
                  if (!Trainable)
                  {
                      // A frozen model reports every trainable weight as non-trainable
                      // (TrainableWeights returns an empty list in that case).
                      foreach (var trackable_obj in _self_tracked_trackables)
                      {
                          variables.AddRange(trackable_obj.TrainableWeights);
                      }
                      variables.AddRange(_trainable_weights);
                  }
                  // The model's own non-trainable weights belong here whether or not it is trainable.
                  variables.AddRange(_non_trainable_weights);
                  return variables.Distinct().ToList();
              }
          }

          /// <summary>
          /// Returns the trackable children reported by the base implementation.
          /// </summary>
          public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
          {
              if (save_type == SaveType.SAVEDMODEL)
              {
                  //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`.
              }
              var children = base._trackable_children(save_type, cache);
              return children;
          }

          /// <summary>Sets attribute <paramref name="name"/> by delegating to the base implementation.</summary>
          public override void SetAttr(string name, object value)
          {
              // TODO(Rinne): deal with "_self_setattr_tracking".
              //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v)))
              //{
              // this._base_model_initialized;
              //}
              base.SetAttr(name, value);
          }

          /// <summary>Sets the stop-training flag to true.</summary>
          void IModel.set_stopTraining_true()
          {
              stop_training = true;
          }
      }
  }