diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs b/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs index e088fdaf..dd037e24 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.FlattenLayers.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine yield return this; var seen_object_ids = new List(); - var deque = new Queue(_layers); + var deque = new Queue(_self_tracked_trackables); while (!deque.empty()) { var layer_or_container = deque.Dequeue(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index 32535838..f38750c2 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -6,7 +6,7 @@ namespace Tensorflow.Keras.Engine public partial class Layer { protected List _layers = new List(); - public List Layers => _layers; + public virtual List Layers => _layers; protected void StackLayers(params ILayer[] layers) { diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 4ae94b3d..baf68229 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; @@ -70,6 +71,9 @@ namespace Tensorflow.Keras.Engine aggregation: VariableAggregation.OnlyFirstReplica); } + public override List Layers + => _flatten_layers(recursive: false, include_self: false).ToList(); + public override List trainable_variables { get diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 7d8c77fe..47e6c3f7 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -202,5 +202,8 @@ namespace Tensorflow.Keras.Engine created_nodes.add(prev_layer.OutboundNodes.Last()); } } + + public override List Layers + => base.Layers.Where(x => x is not InputLayer).ToList(); } }