Browse Source

Add InboundLayers to Node

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
d1fc44dcef
8 changed files with 42 additions and 7 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/INode.cs
  3. +7
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Numpy/Shape.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  6. +2
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  7. +4
    -0
      src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs
  8. +20
    -4
      src/TensorFlowNET.Keras/Engine/Sequential.cs

+ 1
- 1
src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow.Contexts
if (args.GetGradientAttrs == null)
{
attrs = new Dictionary<string, object>();
attrs["T"] = op.get_attr<TF_DataType>("T");
attrs["T"] = op.dtype;
}
else
{


+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/INode.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine
ILayer Layer { get; }
List<Tensor> KerasInputs { get; set; }
INode[] ParentNodes { get; }
ILayer[] InboundLayers { get; }
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
bool is_input { get; }
List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);


+ 7
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -88,6 +88,13 @@ namespace Tensorflow.NumPy

public static bool Equals(Shape shape, object target)
{
if (shape is null && target is null)
return true;
else if (shape is null && target is not null)
return false;
else if (shape is not null && target is null)
return false;

switch (target)
{
case Shape shape1:


+ 6
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -253,5 +253,11 @@ namespace Tensorflow
public override bool Equals(object obj) => ShapeHelper.Equals(this, obj);

public override string ToString() => ShapeHelper.ToString(this);

public static bool operator ==(Shape a, Shape b)
=> ShapeHelper.Equals(a, b);

public static bool operator !=(Shape a, Shape b)
=> !ShapeHelper.Equals(a, b);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow

public int _id_value { get; set; }
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
public TF_DataType dtype => output.dtype;
public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle));



+ 2
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -69,7 +69,8 @@ namespace Tensorflow.Keras.Engine

NetworkNodes = nodes;
NodesByDepth = nodes_by_depth;
_layers = layers;
if (_layers.Count == 0)
_layers = layers;

// Build self.input_names and self.output_names.
_set_output_names();


+ 4
- 0
src/TensorFlowNET.Keras/Engine/Node.IterateInbound.cs View File

@@ -1,9 +1,13 @@
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Engine
{
public partial class Node
{
public ILayer[] InboundLayers
=> iterate_inbound().Select(x => x.Item1).ToArray();

public IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound()
{
foreach (var kt in KerasInputs)


+ 20
- 4
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -150,6 +150,9 @@ namespace Tensorflow.Keras.Engine

void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
{
if (_inferred_input_shape == input_shape)
return;

ops.init_scope();
var inputs = keras.Input(batch_input_shape: input_shape,
dtype: input_dtype,
@@ -157,16 +160,17 @@ namespace Tensorflow.Keras.Engine
Tensors layer_input = inputs;
Tensors layer_output = null;
Tensors outputs = null;
List<INode> created_nodes = new List<INode>();
foreach (var layer in _layers)
{
clear_previously_created_nodes(layer, _created_nodes);
layer_output = layer.Apply(layer_input);
// Keep track of nodes just created above
track_nodes_created_by_last_call(layer, _created_nodes);
track_nodes_created_by_last_call(layer, created_nodes);
layer_input = layer_output;
outputs = layer_output;
}
_created_nodes = created_nodes;
_init_graph_network(inputs, outputs);
_graph_initialized = true;
_inferred_input_shape = input_shape;
@@ -174,16 +178,28 @@ namespace Tensorflow.Keras.Engine

void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes)
{
foreach(var node in layer.InboundNodes)
{
foreach(var prev_layer in node.InboundLayers)
{
var outNodes = prev_layer.OutboundNodes.Where(x => !created_nodes.Contains(x)).ToArray();
prev_layer.OutboundNodes.Clear();
prev_layer.OutboundNodes.AddRange(outNodes);
}
}

var inNodes = layer.InboundNodes.Where(x => !created_nodes.Contains(x)).ToArray();
layer.InboundNodes.Clear();
layer.InboundNodes.AddRange(inNodes);
}

void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
{
var node = layer.InboundNodes.Last();
created_nodes.Add(node);
foreach(var prev_layer in node.iterate_inbound())
foreach(var prev_layer in node.InboundLayers)
{
created_nodes.add(prev_layer.Item1.OutboundNodes.Last());
created_nodes.add(prev_layer.OutboundNodes.Last());
}
}
}


Loading…
Cancel
Save