diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 31902f14..5d9d799d 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -98,35 +98,23 @@ namespace Tensorflow
default:
return obj?.ToString() ?? "null";
}
-
- object[] toObjectArray(Array arr)
- {
- var len = arr.LongLength;
- var ret = new object[len];
- for (long i = 0; i < len; i++)
- {
- ret[i] = arr.GetValue(i);
- }
-
- return ret;
- }
}
-        private static TextWriter writer = null;
+        // Backing store for tf_output_redirect; initialised from Console.Out when the type is initialised.
+        private static TextWriter _writer = Console.Out;
+        /// <summary>
+        /// Gets or sets the stored <see cref="TextWriter"/>.
+        /// Assigning a new writer first flushes the one currently stored and, when that
+        /// one is a <see cref="StringWriter"/>, empties its underlying buffer.
+        /// Reading yields <see cref="Console.Out"/> whenever the stored writer is null.
+        /// </summary>
         public static TextWriter tf_output_redirect {
             set
             {
-                var originWriter = writer ?? Console.Out;
-                originWriter.Flush();
-                if (originWriter is StringWriter)
-                    (originWriter as StringWriter).GetStringBuilder().Clear();
-                writer = value;
-            }
-            get
-            {
-                return writer ?? Console.Out;
+                // Null-conditional flush: nothing to do when no writer is stored.
+                _writer?.Flush();
+
+                // The type pattern never matches null, so no separate null test is needed here.
+                if (_writer is StringWriter buffered)
+                    buffered.GetStringBuilder().Clear();
+
+                _writer = value;
             }
+            get => _writer ?? Console.Out;
         }
public static void print(object obj)
diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs
index 60f39b60..3c70739b 100644
--- a/src/TensorFlowNET.Core/DisposableObject.cs
+++ b/src/TensorFlowNET.Core/DisposableObject.cs
@@ -48,7 +48,7 @@ namespace Tensorflow
}
// free unmanaged memory
- // if (_handle != IntPtr.Zero)
+ if (_handle != IntPtr.Zero)
{
// Call the appropriate methods to clean up
// unmanaged resources here.
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
index fb37a89c..7d3721f1 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
@@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine
///
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
{
- callContext = callContext?.Value != null ? callContext : new ThreadLocal()
- {
- Value = new CallContext()
- };
+ if (callContext.Value == null)
+ callContext.Value = new CallContext();
if (_in_functional_construction_mode(inputs))
return FunctionalConstructionCall(inputs);
- Tensors outputs = null;
-
var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter(build_graph: false);
- string nameScope = "";
- if (eager)
- nameScope = Name;
- else
- nameScope = _name_scope();
-
+ string nameScope = eager ? name : _name_scope();
var scope = ops.name_scope(nameScope);
scope.__enter__();
if (!built)
MaybeBuild(inputs);
- outputs = Call(inputs, state: state, training: training);
+ var outputs = Call(inputs, state: state, training: training);
// memory leak
// _set_connectivity_metadata_(inputs, outputs);
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index e9d58b6f..7496e071 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine
List outboundNodes;
public List OutboundNodes => outboundNodes;
- ThreadLocal callContext;
+ ThreadLocal callContext = new ThreadLocal();
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;
public Dictionary> NodesByDepth { get; set; }
public Shape output_shape => inboundNodes[0].Outputs.shape;
+ protected List _self_tracked_trackables;
+
public Layer(LayerArgs args)
{
this.args = args;
@@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine
non_trainable_weights = new List();
computePreviousMask = false;
updates = new List();
+ _self_tracked_trackables = new List();
inboundNodes = new List();
outboundNodes = new List();
diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs
index d41a5572..3d7832b8 100644
--- a/src/TensorFlowNET.Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System;
using System.Linq;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
@@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine
bool _auto_track_sub_layers;
Shape _inferred_input_shape;
bool _has_explicit_input_shape;
-
+ bool _graph_initialized;
public Shape output_shape => outputs[0].shape;
+ List _created_nodes;
public Sequential(SequentialArgs args)
: base(args.Inputs, args.Outputs, name: args.Name)
@@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine
_auto_track_sub_layers = false;
_has_explicit_input_shape = false;
_is_graph_network = false;
+ _created_nodes = new List();
// Add to the model any layers passed to the constructor.
if (args.Layers != null)
{
foreach (var layer in args.Layers)
- add(layer as Layer);
+ add(layer);
}
}
@@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine
}
else
{
+ _self_tracked_trackables.add(layer);
+ _handle_deferred_layer_dependencies(layer);
+ }
+ }
+        /// <summary>
+        /// Appends every element of <paramref name="layers"/> to <c>_layers</c>.
+        /// </summary>
+        /// <param name="layers">Layers to append, in the order given.</param>
+        void _handle_deferred_layer_dependencies(params ILayer[] layers)
+            => _layers.AddRange(layers);
+
+        /// <summary>
+        /// Forwards the call to the base implementation after making sure the
+        /// graph network has been prepared.
+        /// </summary>
+        /// <param name="inputs">Input tensors; their shape and dtype are used when no explicit input shape was given.</param>
+        /// <param name="state">Passed through unchanged to the base implementation.</param>
+        /// <param name="training">Passed through unchanged to the base implementation.</param>
+        /// <returns>Whatever the base <c>Call</c> returns.</returns>
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+        {
+            // No explicit input shape: derive the graph network from the incoming tensors.
+            if (!_has_explicit_input_shape)
+                _build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype);
+
+            // `built` is only consulted once the graph has been initialised (short-circuit).
+            var needsGraphInit = _graph_initialized && !built;
+            if (needsGraphInit)
+                _init_graph_network(this.inputs, outputs);
+
+            // The base implementation runs regardless of the flags above.
+            return base.Call(inputs, state, training);
+        }
+
+        /// <summary>
+        /// Creates a symbolic input via <c>keras.Input</c>, applies every layer in
+        /// <c>_layers</c> to it in order, and initialises the graph network from the
+        /// resulting input/output pair. On success, records the shape in
+        /// <c>_inferred_input_shape</c> and sets <c>_graph_initialized</c>.
+        /// </summary>
+        /// <param name="input_shape">Shape forwarded as <c>batch_input_shape</c> to <c>keras.Input</c>.</param>
+        /// <param name="input_dtype">Data type forwarded as <c>dtype</c> to <c>keras.Input</c>.</param>
+        void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
+        {
+            // The graph was already built by this method for an equal shape: rebuilding would
+            // re-apply every layer and keep appending to _created_nodes on each call.
+            // NOTE(review): this relies on Shape.Equals comparing by value; with reference-only
+            // equality the guard simply never fires for a fresh Shape instance — confirm.
+            if (_graph_initialized && object.Equals(_inferred_input_shape, input_shape))
+                return;
+
+            ops.init_scope();
+            var inputs = keras.Input(batch_input_shape: input_shape,
+                dtype: input_dtype,
+                name: $"{_layers[0].Name}_input");
+
+            // Thread the tensors through the layers: each layer consumes the previous output.
+            Tensors outputs = inputs;
+            foreach (var layer in _layers)
+            {
+                clear_previously_created_nodes(layer, _created_nodes);
+                outputs = layer.Apply(outputs);
+                // Keep track of nodes just created above
+                track_nodes_created_by_last_call(layer, _created_nodes);
+            }
+
+            _init_graph_network(inputs, outputs);
+            _graph_initialized = true;
+            _inferred_input_shape = input_shape;
+        }
+
+        /// <summary>
+        /// Currently performs no work: neither argument is read or modified.
+        /// </summary>
+        /// <param name="layer">Unused.</param>
+        /// <param name="created_nodes">Unused.</param>
+        void clear_previously_created_nodes(ILayer layer, List created_nodes)
+        {
+            // NOTE(review): presumably meant to detach the nodes listed in created_nodes
+            // from the layer before it is applied again — confirm and implement.
+        }
+
+        /// <summary>
+        /// Records the last inbound node of <paramref name="layer"/> in
+        /// <paramref name="created_nodes"/>, followed by the last outbound node of every
+        /// layer yielded by that node's <c>iterate_inbound()</c>.
+        /// </summary>
+        /// <param name="layer">Layer whose most recent inbound node is inspected.</param>
+        /// <param name="created_nodes">List that receives the recorded nodes.</param>
+        void track_nodes_created_by_last_call(ILayer layer, List created_nodes)
+        {
+            var latestInbound = layer.InboundNodes.Last();
+            created_nodes.Add(latestInbound);
+
+            foreach (var upstream in latestInbound.iterate_inbound())
+            {
+                // Item1 is the upstream layer; only its most recent outbound node is recorded.
+                var producer = upstream.Item1;
+                created_nodes.add(producer.OutboundNodes.Last());
+            }
         }
}
diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
index fb813455..f3956811 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers
var rank = inputs.rank;
if (rank > 2)
{
- throw new NotImplementedException("call rank > 2");
+ outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } });
}
else
{