From d1fc44dcef9c148887d89895815921fa7a1e5c64 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 1 Jan 2022 20:01:25 -0600 Subject: [PATCH] Add InboundLayers to Node --- .../Contexts/Context.ExecuteOp.cs | 2 +- src/TensorFlowNET.Core/Keras/Engine/INode.cs | 1 + src/TensorFlowNET.Core/NumPy/ShapeHelper.cs | 7 ++++++ src/TensorFlowNET.Core/Numpy/Shape.cs | 6 +++++ .../Operations/Operation.cs | 2 +- src/TensorFlowNET.Keras/Engine/Functional.cs | 3 ++- .../Engine/Node.IterateInbound.cs | 4 ++++ src/TensorFlowNET.Keras/Engine/Sequential.cs | 24 +++++++++++++++---- 8 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs index 5b256455..ac1cd866 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs @@ -78,7 +78,7 @@ namespace Tensorflow.Contexts if (args.GetGradientAttrs == null) { attrs = new Dictionary(); - attrs["T"] = op.get_attr("T"); + attrs["T"] = op.dtype; } else { diff --git a/src/TensorFlowNET.Core/Keras/Engine/INode.cs b/src/TensorFlowNET.Core/Keras/Engine/INode.cs index 83e1bb00..bd778f6c 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/INode.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/INode.cs @@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine ILayer Layer { get; } List KerasInputs { get; set; } INode[] ParentNodes { get; } + ILayer[] InboundLayers { get; } IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); bool is_input { get; } List serialize(Func make_node_key, Dictionary node_conversion_map); diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index 832a6658..9c9ae7d3 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -88,6 +88,13 @@ namespace Tensorflow.NumPy public static bool Equals(Shape shape, object target) { + if (shape is null && target is null) + return true; + else if (shape is null && target is not null) + return false; + else if (shape is not null && target is null) + return false; + switch (target) { case Shape shape1: diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index dd2981e7..bc79fefc 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -253,5 +253,11 @@ namespace Tensorflow public override bool Equals(object obj) => ShapeHelper.Equals(this, obj); public override string ToString() => ShapeHelper.ToString(this); + + public static bool operator ==(Shape a, Shape b) + => ShapeHelper.Equals(a, b); + + public static bool operator !=(Shape a, Shape b) + => !ShapeHelper.Equals(a, b); } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index cb018700..fb9a4a27 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -55,7 +55,7 @@ namespace Tensorflow public int _id_value { get; set; } public Operation op => this; - public TF_DataType dtype => TF_DataType.DtInvalid; + public TF_DataType dtype => output.dtype; public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 1d9396f4..01d84794 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -69,7 +69,8 @@ namespace Tensorflow.Keras.Engine NetworkNodes = nodes; NodesByDepth = nodes_by_depth; - _layers = layers; + if (_layers.Count == 0) + _layers = layers; // Build self.input_names and self.output_names. _set_output_names(); diff --git a/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs b/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs index 359d36c9..5da2fa44 100644 --- a/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs +++ b/src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs @@ -1,9 +1,13 @@ using System.Collections.Generic; +using System.Linq; namespace Tensorflow.Keras.Engine { public partial class Node { + public ILayer[] InboundLayers + => iterate_inbound().Select(x => x.Item1).ToArray(); + public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound() { foreach (var kt in KerasInputs) diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 3d7832b8..7d8c77fe 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -150,6 +150,9 @@ namespace Tensorflow.Keras.Engine void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) { + if (_inferred_input_shape == input_shape) + return; + ops.init_scope(); var inputs = keras.Input(batch_input_shape: input_shape, dtype: input_dtype, @@ -157,16 +160,17 @@ namespace Tensorflow.Keras.Engine Tensors layer_input = inputs; Tensors layer_output = null; Tensors outputs = null; - + List created_nodes = new List(); foreach (var layer in _layers) { clear_previously_created_nodes(layer, _created_nodes); layer_output = layer.Apply(layer_input); // Keep track of nodes just created above - track_nodes_created_by_last_call(layer, _created_nodes); + track_nodes_created_by_last_call(layer, created_nodes); layer_input = layer_output; outputs = layer_output; } + _created_nodes = created_nodes; _init_graph_network(inputs, outputs); _graph_initialized = true; _inferred_input_shape = input_shape; @@ -174,16 +178,28 @@ namespace Tensorflow.Keras.Engine void clear_previously_created_nodes(ILayer layer, List created_nodes) { + foreach(var node in layer.InboundNodes) + { + foreach(var prev_layer in node.InboundLayers) + { + var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); + prev_layer.OutboundNodes.Clear(); + prev_layer.OutboundNodes.AddRange(outNodes); + } + } + var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray(); + layer.InboundNodes.Clear(); + layer.InboundNodes.AddRange(inNodes); } void track_nodes_created_by_last_call(ILayer layer, List created_nodes) { var node = layer.InboundNodes.Last(); created_nodes.Add(node); - foreach(var prev_layer in node.iterate_inbound()) + foreach(var prev_layer in node.InboundLayers) { - created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); + created_nodes.add(prev_layer.OutboundNodes.Last()); } } }