using System.Diagnostics;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine
{
    /// <summary>
    /// `Model` groups layers into an object with training and inference features.
    /// </summary>
    public partial class Model : Layer, IModel
    {
#pragma warning disable CS0169 // The field 'Model._cloning' is never used
        bool _cloning;
#pragma warning restore CS0169 // The field 'Model._cloning' is never used
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
        bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
        ILossFunc loss;
        IOptimizer optimizer;
        IVariableV1 _steps_per_execution;
        protected bool _is_graph_network;
        public Tensors inputs;
        protected Tensors outputs;
        protected List<string> input_names;
        public string[] output_names;
        IVariableV1 _train_counter;
        IVariableV1 _test_counter;
        IVariableV1 _predict_counter;
        bool _base_model_initialized;
        bool stop_training;

        // Input signature recorded by _set_save_spec; null until the first successful call.
        TensorSpec _saved_model_inputs_spec;

        /// <summary>Whether this model was built as a graph network.</summary>
        public bool IsGraphNetwork => _is_graph_network;

        /// <summary>The optimizer associated with this model.</summary>
        public IOptimizer Optimizer
        {
            get => optimizer;
            set => optimizer = value;
        }

        /// <summary>Flag backing <see cref="IModel.set_stopTraining_true"/>; can also be set directly.</summary>
        public bool Stop_training
        {
            get => stop_training;
            set => stop_training = value;
        }

        public Model(ModelArgs args)
            : base(args)
        {
            _init_batch_counters();
        }

        /// <summary>Records <paramref name="inputs"/> as the model's save signature (see <see cref="_set_save_spec"/>).</summary>
        public void _set_inputs(TensorSpec inputs)
        {
            _set_save_spec(inputs);
        }

        /// <summary>
        /// Records the input signature used when saving the model. Returns immediately if a
        /// signature has already been recorded.
        /// </summary>
        /// <param name="inputs">Spec describing the model inputs.</param>
        internal void _set_save_spec(TensorSpec inputs)
        {
            if (_saved_model_inputs_spec is not null)
            {
                return;
            }

            // Fall back to generated names when the model has no input names of its own.
            var input_names = this.input_names;
            if (input_names is null || input_names.Count == 0)
            {
                input_names = compile_utils.create_pseudo_input_names(inputs);
            }

            var flat_inputs = nest.flatten(inputs);
            List<TensorSpec> specs = new();
            foreach (var (name, tensor) in zip(input_names, flat_inputs))
            {
                specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name));
            }

            var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec;
            // The `as` cast yields null if the packed structure is not a TensorSpec; that is the
            // value being stored, so check it (the `specs` list itself can never be null here).
            Debug.Assert(packed_specs is not null, "Packed input specs are expected to be a TensorSpec.");
            _saved_model_inputs_spec = packed_specs;

            // Sequential models may not know their build input shape yet; derive it from the specs.
            if (this is Sequential && _buildInputShape is null)
            {
                _buildInputShape = nest.map_structure(x => x is null ? null : x.shape, packed_specs);
            }
        }

        internal override void Initialize(LayerArgs args)
        {
            _init_batch_counters();
            base.Initialize(args);
        }

        void _configure_steps_per_execution(int steps_per_execution)
        {
            _steps_per_execution = tf.Variable(steps_per_execution,
                dtype: TF_DataType.TF_INT64,
                aggregation: VariableAggregation.OnlyFirstReplica);
        }

        void _reset_compile_cache()
        {
            // Used to cache `trainable` attr of `Layer`s for `fit`.
            _compiled_trainable_state = _get_trainable_state();
            keras.backend._GRAPH = null;
        }

        /// <summary>
        /// Creates the train/test/predict counters as int64 variables initialised to 0
        /// (aggregation: OnlyFirstReplica).
        /// </summary>
        void _init_batch_counters()
        {
            _train_counter = tf.Variable(0L,
                dtype: TF_DataType.TF_INT64,
                aggregation: VariableAggregation.OnlyFirstReplica);

            _test_counter = tf.Variable(0L,
                dtype: TF_DataType.TF_INT64,
                aggregation: VariableAggregation.OnlyFirstReplica);

            _predict_counter = tf.Variable(0L,
                dtype: TF_DataType.TF_INT64,
                aggregation: VariableAggregation.OnlyFirstReplica);
        }

        /// <summary>The directly contained layers (non-recursive, excluding the model itself).</summary>
        public override List<ILayer> Layers
            => _flatten_layers(recursive: false, include_self: false).ToList();

        /// <summary>
        /// Trainable variables of this model, deduplicated. Empty when the model itself is not
        /// trainable; otherwise the trainable weights of every trainable tracked child followed by
        /// the model's own trainable weights.
        /// </summary>
        public override List<IVariableV1> TrainableWeights
        {
            get
            {
                // skip the assertion of weights created.
                var variables = new List<IVariableV1>();
                if (!Trainable)
                {
                    return variables;
                }

                foreach (var trackable_obj in _self_tracked_trackables)
                {
                    if (trackable_obj.Trainable)
                        variables.AddRange(trackable_obj.TrainableWeights);
                }

                variables.AddRange(_trainable_weights);

                return variables.Distinct().ToList();
            }
        }

        /// <summary>
        /// Non-trainable variables of this model, deduplicated.
        /// </summary>
        /// <remarks>
        /// Always contains the non-trainable weights of the tracked children and the model's own
        /// non-trainable weights. When the model is frozen (<c>Trainable == false</c>) the trainable
        /// weights of the children and of the model are reported here as well, because
        /// <see cref="TrainableWeights"/> returns an empty list in that case.
        /// </remarks>
        public override List<IVariableV1> NonTrainableWeights
        {
            get
            {
                // skip the assertion of weights created.
                var variables = new List<IVariableV1>();
                foreach (var trackable_obj in _self_tracked_trackables)
                {
                    variables.AddRange(trackable_obj.NonTrainableWeights);
                }

                if (!Trainable)
                {
                    // A frozen model exposes all of its trainable weights as non-trainable.
                    foreach (var trackable_obj in _self_tracked_trackables)
                    {
                        variables.AddRange(trackable_obj.TrainableWeights);
                    }
                    variables.AddRange(_trainable_weights);
                }

                // The model's own non-trainable weights are included regardless of Trainable.
                variables.AddRange(_non_trainable_weights);

                return variables.Distinct().ToList();
            }
        }

        /// <summary>
        /// Returns the children reported by the base implementation. SavedModel-specific children
        /// (train/test/predict functions) are not handled yet.
        /// </summary>
        public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
        {
            if (save_type == SaveType.SAVEDMODEL)
            {
                //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`.
            }

            var children = base._trackable_children(save_type, cache);
            return children;
        }

        /// <summary>Currently forwards directly to the base implementation.</summary>
        public override void SetAttr(string name, object value)
        {
            // TODO(Rinne): deal with "_self_setattr_tracking".

            //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v)))
            //{
            //    this._base_model_initialized;
            //}

            base.SetAttr(name, value);
        }

        void IModel.set_stopTraining_true()
        {
            stop_training = true;
        }
    }
}