From 18d2512ee5c051d66d7c558991c216861df57738 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 1 Jan 2022 10:13:33 -0600 Subject: [PATCH] fix keras sequential. --- src/TensorFlowNET.Core/Binding.Util.cs | 32 +++------ src/TensorFlowNET.Core/DisposableObject.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.Apply.cs | 17 ++--- src/TensorFlowNET.Keras/Engine/Layer.cs | 5 +- src/TensorFlowNET.Keras/Engine/Sequential.cs | 69 ++++++++++++++++++- src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 2 +- 6 files changed, 87 insertions(+), 40 deletions(-) diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 31902f14..5d9d799d 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -98,35 +98,23 @@ namespace Tensorflow default: return obj?.ToString() ?? "null"; } - - object[] toObjectArray(Array arr) - { - var len = arr.LongLength; - var ret = new object[len]; - for (long i = 0; i < len; i++) - { - ret[i] = arr.GetValue(i); - } - - return ret; - } } - private static TextWriter writer = null; + private static TextWriter _writer = Console.Out; public static TextWriter tf_output_redirect { set { - var originWriter = writer ?? Console.Out; - originWriter.Flush(); - if (originWriter is StringWriter) - (originWriter as StringWriter).GetStringBuilder().Clear(); - writer = value; - } - get - { - return writer ?? Console.Out; + if(_writer != null) + { + _writer.Flush(); + if (_writer is StringWriter sw) + sw.GetStringBuilder().Clear(); + } + + _writer = value; } + get => _writer ?? Console.Out; } public static void print(object obj) diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 60f39b60..3c70739b 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -48,7 +48,7 @@ namespace Tensorflow } // free unmanaged memory - // if (_handle != IntPtr.Zero) + if (_handle != IntPtr.Zero) { // Call the appropriate methods to clean up // unmanaged resources here. diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index fb37a89c..7d3721f1 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine /// public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) { - callContext = callContext?.Value != null ? callContext : new ThreadLocal() - { - Value = new CallContext() - }; + if (callContext.Value == null) + callContext.Value = new CallContext(); if (_in_functional_construction_mode(inputs)) return FunctionalConstructionCall(inputs); - Tensors outputs = null; - var eager = tf.executing_eagerly(); using var ctxManager = CallContext.enter(build_graph: false); - string nameScope = ""; - if (eager) - nameScope = Name; - else - nameScope = _name_scope(); - + string nameScope = eager ? name : _name_scope(); var scope = ops.name_scope(nameScope); scope.__enter__(); if (!built) MaybeBuild(inputs); - outputs = Call(inputs, state: state, training: training); + var outputs = Call(inputs, state: state, training: training); // memory leak // _set_connectivity_metadata_(inputs, outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e9d58b6f..7496e071 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine List outboundNodes; public List OutboundNodes => outboundNodes; - ThreadLocal callContext; + ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; public Tensor[] input => inboundNodes[0].input_tensors; public Dictionary> NodesByDepth { get; set; } public Shape output_shape => inboundNodes[0].Outputs.shape; + protected List _self_tracked_trackables; + public Layer(LayerArgs args) { this.args = args; @@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine non_trainable_weights = new List(); computePreviousMask = false; updates = new List(); + _self_tracked_trackables = new List(); inboundNodes = new List(); outboundNodes = new List(); diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index d41a5572..3d7832b8 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using System.Linq; using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; @@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine bool _auto_track_sub_layers; Shape _inferred_input_shape; bool _has_explicit_input_shape; - + bool _graph_initialized; public Shape output_shape => outputs[0].shape; + List _created_nodes; public Sequential(SequentialArgs args) : base(args.Inputs, args.Outputs, name: args.Name) @@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine _auto_track_sub_layers = false; _has_explicit_input_shape = false; _is_graph_network = false; + _created_nodes = new List(); // Add to the model any layers passed to the constructor. if (args.Layers != null) { foreach (var layer in args.Layers) - add(layer as Layer); + add(layer); } } @@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine } else { + _self_tracked_trackables.add(layer); + _handle_deferred_layer_dependencies(layer); + } + } + void _handle_deferred_layer_dependencies(params ILayer[] layers) + { + _layers.AddRange(layers); + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + if (!_has_explicit_input_shape) + { + _build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype); + } + + if(_graph_initialized) + { + if (!built) + _init_graph_network(this.inputs, outputs); + return base.Call(inputs, state, training); + } + + return base.Call(inputs, state, training); + } + + void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) + { + ops.init_scope(); + var inputs = keras.Input(batch_input_shape: input_shape, + dtype: input_dtype, + name: $"{_layers[0].Name}_input"); + Tensors layer_input = inputs; + Tensors layer_output = null; + Tensors outputs = null; + + foreach (var layer in _layers) + { + clear_previously_created_nodes(layer, _created_nodes); + layer_output = layer.Apply(layer_input); + // Keep track of nodes just created above + track_nodes_created_by_last_call(layer, _created_nodes); + layer_input = layer_output; + outputs = layer_output; + } + _init_graph_network(inputs, outputs); + _graph_initialized = true; + _inferred_input_shape = input_shape; + } + + void clear_previously_created_nodes(ILayer layer, List created_nodes) + { + + } + + void track_nodes_created_by_last_call(ILayer layer, List created_nodes) + { + var node = layer.InboundNodes.Last(); + created_nodes.Add(node); + foreach(var prev_layer in node.iterate_inbound()) + { + created_nodes.add(prev_layer.Item1.OutboundNodes.Last()); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index fb813455..f3956811 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers var rank = inputs.rank; if (rank > 2) { - throw new NotImplementedException("call rank > 2"); + outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } }); } else {