Browse Source

fix keras sequential.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
18d2512ee5
6 changed files with 87 additions and 40 deletions
  1. +10
    -22
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/DisposableObject.cs
  3. +4
    -13
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  4. +4
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  5. +67
    -2
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  6. +1
    -1
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs

+ 10
- 22
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -98,35 +98,23 @@ namespace Tensorflow
default:
return obj?.ToString() ?? "null";
}

object[] toObjectArray(Array arr)
{
var len = arr.LongLength;
var ret = new object[len];
for (long i = 0; i < len; i++)
{
ret[i] = arr.GetValue(i);
}

return ret;
}
}

private static TextWriter writer = null;
private static TextWriter _writer = Console.Out;

public static TextWriter tf_output_redirect {
set
{
var originWriter = writer ?? Console.Out;
originWriter.Flush();
if (originWriter is StringWriter)
(originWriter as StringWriter).GetStringBuilder().Clear();
writer = value;
}
get
{
return writer ?? Console.Out;
if(_writer != null)
{
_writer.Flush();
if (_writer is StringWriter sw)
sw.GetStringBuilder().Clear();
}

_writer = value;
}
get => _writer ?? Console.Out;
}

public static void print(object obj)


+ 1
- 1
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow
}

// free unmanaged memory
// if (_handle != IntPtr.Zero)
if (_handle != IntPtr.Zero)
{
// Call the appropriate methods to clean up
// unmanaged resources here.


+ 4
- 13
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -14,32 +14,23 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
{
callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>()
{
Value = new CallContext()
};
if (callContext.Value == null)
callContext.Value = new CallContext();

if (_in_functional_construction_mode(inputs))
return FunctionalConstructionCall(inputs);

Tensors outputs = null;

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter(build_graph: false);

string nameScope = "";
if (eager)
nameScope = Name;
else
nameScope = _name_scope();

string nameScope = eager ? name : _name_scope();
var scope = ops.name_scope(nameScope);
scope.__enter__();

if (!built)
MaybeBuild(inputs);

outputs = Call(inputs, state: state, training: training);
var outputs = Call(inputs, state: state, training: training);

// memory leak
// _set_connectivity_metadata_(inputs, outputs);


+ 4
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -84,11 +84,13 @@ namespace Tensorflow.Keras.Engine
List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes;

ThreadLocal<CallContext> callContext;
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;
public Dictionary<int, List<INode>> NodesByDepth { get; set; }
public Shape output_shape => inboundNodes[0].Outputs.shape;
protected List<ILayer> _self_tracked_trackables;

public Layer(LayerArgs args)
{
this.args = args;
@@ -106,6 +108,7 @@ namespace Tensorflow.Keras.Engine
non_trainable_weights = new List<IVariableV1>();
computePreviousMask = false;
updates = new List<Operation>();
_self_tracked_trackables = new List<ILayer>();

inboundNodes = new List<INode>();
outboundNodes = new List<INode>();


+ 67
- 2
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
@@ -35,8 +36,9 @@ namespace Tensorflow.Keras.Engine
bool _auto_track_sub_layers;
Shape _inferred_input_shape;
bool _has_explicit_input_shape;
bool _graph_initialized;
public Shape output_shape => outputs[0].shape;
List<INode> _created_nodes;

public Sequential(SequentialArgs args)
: base(args.Inputs, args.Outputs, name: args.Name)
@@ -49,12 +51,13 @@ namespace Tensorflow.Keras.Engine
_auto_track_sub_layers = false;
_has_explicit_input_shape = false;
_is_graph_network = false;
_created_nodes = new List<INode>();

// Add to the model any layers passed to the constructor.
if (args.Layers != null)
{
foreach (var layer in args.Layers)
add(layer as Layer);
add(layer);
}
}

@@ -118,7 +121,69 @@ namespace Tensorflow.Keras.Engine
}
else
{
_self_tracked_trackables.add(layer);
_handle_deferred_layer_dependencies(layer);
}
}

void _handle_deferred_layer_dependencies(params ILayer[] layers)
{
_layers.AddRange(layers);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (!_has_explicit_input_shape)
{
_build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype);
}

if(_graph_initialized)
{
if (!built)
_init_graph_network(this.inputs, outputs);
return base.Call(inputs, state, training);
}

return base.Call(inputs, state, training);
}

void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
{
ops.init_scope();
var inputs = keras.Input(batch_input_shape: input_shape,
dtype: input_dtype,
name: $"{_layers[0].Name}_input");
Tensors layer_input = inputs;
Tensors layer_output = null;
Tensors outputs = null;
foreach (var layer in _layers)
{
clear_previously_created_nodes(layer, _created_nodes);
layer_output = layer.Apply(layer_input);
// Keep track of nodes just created above
track_nodes_created_by_last_call(layer, _created_nodes);
layer_input = layer_output;
outputs = layer_output;
}
_init_graph_network(inputs, outputs);
_graph_initialized = true;
_inferred_input_shape = input_shape;
}

void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes)
{

}

void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
{
var node = layer.InboundNodes.Last();
created_nodes.Add(node);
foreach(var prev_layer in node.iterate_inbound())
{
created_nodes.add(prev_layer.Item1.OutboundNodes.Last());
}
}
}


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Layers
var rank = inputs.rank;
if (rank > 2)
{
throw new NotImplementedException("call rank > 2");
outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } });
}
else
{


Loading…
Cancel
Save