From a432807c2351a961a0a2007ae18bbdb28c19503c Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Fri, 30 Dec 2022 12:43:36 -0600 Subject: [PATCH] Remove _layers in Layer. --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 8 ++++---- src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs | 8 ++++---- src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs | 4 ++-- src/TensorFlowNET.Keras/Engine/Functional.cs | 7 +------ src/TensorFlowNET.Keras/Engine/Layer.Layers.cs | 3 +-- src/TensorFlowNET.Keras/Engine/Layer.cs | 8 ++++---- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.cs | 8 ++++---- src/TensorFlowNET.Keras/Engine/Sequential.cs | 8 ++++---- src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 4 ++-- src/TensorFlowNET.Keras/Utils/layer_utils.cs | 6 +++--- 11 files changed, 30 insertions(+), 36 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 271fece0..f77b4a86 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -13,10 +13,10 @@ namespace Tensorflow.Keras List InboundNodes { get; } List OutboundNodes { get; } Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); - List trainable_variables { get; } - List trainable_weights { get; } - List non_trainable_weights { get; } - Shape output_shape { get; } + List TrainableVariables { get; } + List TrainableWeights { get; } + List NonTrainableWeights { get; } + Shape OutputShape { get; } Shape BatchInputShape { get; } TF_DataType DType { get; } int count_params(); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 041268b7..04fdc7e5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -67,11 +67,11 @@ namespace Tensorflow public bool Trainable => throw new NotImplementedException(); - public List trainable_variables => throw new NotImplementedException(); - public List trainable_weights => throw new NotImplementedException(); - public List non_trainable_weights => throw new NotImplementedException(); + public List TrainableVariables => throw new NotImplementedException(); + public List TrainableWeights => throw new NotImplementedException(); + public List NonTrainableWeights => throw new NotImplementedException(); - public Shape output_shape => throw new NotImplementedException(); + public Shape OutputShape => throw new NotImplementedException(); public Shape BatchInputShape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 6615810b..23c40fbf 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine }; var node_conversion_map = new Dictionary(); - foreach (var layer in _layers) + foreach (var layer in _self_tracked_trackables) { var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine } var layer_configs = new List(); - foreach (var layer in _layers) + foreach (var layer in _self_tracked_trackables) { var filtered_inbound_nodes = new List(); foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index def842c3..09a31b94 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -65,13 +65,8 @@ namespace Tensorflow.Keras.Engine } // Keep track of the network's nodes and layers. - var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); + (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs); - NetworkNodes = nodes; - NodesByDepth = nodes_by_depth; - if (_layers.Count == 0) - _layers = layers; - _self_tracked_trackables = layers; // Build self.input_names and self.output_names. _set_output_names(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index 488c55cb..a2d212cb 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -5,8 +5,7 @@ namespace Tensorflow.Keras.Engine { public partial class Layer { - protected List _layers = new List(); - public virtual List Layers => _layers; + public virtual List Layers => _self_tracked_trackables; protected void StackLayers(params ILayer[] layers) { diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index d417fa44..ba40b1a2 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine public bool SupportsMasking { get; set; } protected List _trainable_weights; - public virtual List trainable_variables => _trainable_weights; + public virtual List TrainableVariables => _trainable_weights; protected List _non_trainable_weights; public List non_trainable_variables => _non_trainable_weights; @@ -88,7 +88,7 @@ namespace Tensorflow.Keras.Engine public CallContext CallContext => callContext.Value; public Tensor[] input => inboundNodes[0].input_tensors; public Dictionary> NodesByDepth { get; set; } - public Shape output_shape => inboundNodes[0].Outputs.shape; + public Shape OutputShape => inboundNodes[0].Outputs.shape; protected List _self_tracked_trackables; public Layer(LayerArgs args) @@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine return layer_utils.count_params(this, weights); return 0; } - List ILayer.trainable_weights + List ILayer.TrainableWeights { get { @@ -258,7 +258,7 @@ namespace Tensorflow.Keras.Engine } } - List ILayer.non_trainable_weights + List ILayer.NonTrainableWeights { get { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index 31e89c57..f2ff68e9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Engine // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) // The _minimize call does a few extra steps unnecessary in most cases, // such as loss scaling and gradient clipping. - _minimize(tape, optimizer, loss, trainable_variables); + _minimize(tape, optimizer, loss, TrainableVariables); compiled_metrics.update_state(y, y_pred); return metrics.Select(x => (x.Name, x.result())).ToList(); diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index baf68229..162d06c5 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -74,7 +74,7 @@ namespace Tensorflow.Keras.Engine public override List Layers => _flatten_layers(recursive: false, include_self: false).ToList(); - public override List trainable_variables + public override List TrainableVariables { get { @@ -88,13 +88,13 @@ namespace Tensorflow.Keras.Engine foreach (var trackable_obj in _self_tracked_trackables) { if (trackable_obj.Trainable) - variables.AddRange(trackable_obj.trainable_variables); + variables.AddRange(trackable_obj.TrainableVariables); } - foreach (var layer in _layers) + foreach (var layer in _self_tracked_trackables) { if (layer.Trainable) - variables.AddRange(layer.trainable_variables); + variables.AddRange(layer.TrainableVariables); } // variables.AddRange(_trainable_weights); diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 47e6c3f7..681ab2f0 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Engine { built = false; var set_inputs = false; - if (_layers.Count == 0) + if (_self_tracked_trackables.Count == 0) { if (layer is InputLayer) { @@ -128,7 +128,7 @@ namespace Tensorflow.Keras.Engine void _handle_deferred_layer_dependencies(params ILayer[] layers) { - _layers.AddRange(layers); + _self_tracked_trackables.AddRange(layers); } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) @@ -156,12 +156,12 @@ namespace Tensorflow.Keras.Engine ops.init_scope(); var inputs = keras.Input(batch_input_shape: input_shape, dtype: input_dtype, - name: $"{_layers[0].Name}_input"); + name: $"{_self_tracked_trackables[0].Name}_input"); Tensors layer_input = inputs; Tensors layer_output = null; Tensors outputs = null; List created_nodes = new List(); - foreach (var layer in _layers) + foreach (var layer in _self_tracked_trackables) { clear_previously_created_nodes(layer, _created_nodes); layer_output = layer.Apply(layer_input); diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index a3705dfb..b04391be 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -338,8 +338,8 @@ namespace Tensorflow.Keras.Saving public static List _legacy_weights(ILayer layer) { - var weights = layer.trainable_weights.Select(x => x).ToList(); - weights.AddRange(layer.non_trainable_weights); + var weights = layer.TrainableWeights.Select(x => x).ToList(); + weights.AddRange(layer.NonTrainableWeights); return weights; } } diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs index 998086f6..3c38a6d1 100644 --- a/src/TensorFlowNET.Keras/Utils/layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs @@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Utils print(string.Join("", range(line_length).Select(x => "_"))); } - var trainable_count = count_params(model, model.trainable_variables); + var trainable_count = count_params(model, model.TrainableVariables); var non_trainable_count = count_params(model, model.non_trainable_variables); print($"Total params: {trainable_count + non_trainable_count}"); @@ -137,7 +137,7 @@ namespace Tensorflow.Keras.Utils var fields = new string[] { $"{name} ({layer.GetType().Name})", - $"{layer.output_shape}", + $"{layer.OutputShape}", $"{layer.count_params()}" }; @@ -164,7 +164,7 @@ namespace Tensorflow.Keras.Utils var fields = new string[] { $"{name}({layer.GetType().Name})", - $"{layer.output_shape}", + $"{layer.OutputShape}", $"{layer.count_params()}", first_connection };