Browse Source

Implemented support for loading Concatenate layers

model.load_model now supports loading of concatenate layers.
Python TensorFlow exports Concatenate layers in an extra nested array in the manifest, so a check for that was added in generic_utils.cs.
Concatenate was missing `built = true`; this fix prevents the layer from being built multiple times.
Concatenate has 2 or more input nodes, so List<NodeConfig> was required instead of just NodeConfig in Functional.FromConfig.cs.
Added the missing JsonProperty("axis") attribute to MergeArgs (used by Concatenate).
tags/v0.150.0-BERT-Model
Jucko13 2 years ago
parent
commit
93a242c08a
4 changed files with 35 additions and 15 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
  2. +18
    -12
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  3. +1
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
  4. +12
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs View File

@@ -1,13 +1,15 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
public class MergeArgs : LayerArgs
public class MergeArgs : AutoSerializeLayerArgs
{
public Tensors Inputs { get; set; }
[JsonProperty("axis")]
public int Axis { get; set; }
}
}

+ 18
- 12
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine

static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer)
{
ILayer layer = null;
@@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine

created_layers[layer_name] = layer;
}
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);

var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
unprocessed_nodes[layer] = node_data;
unprocessed_nodes[layer] = new List<NodeConfig>() { node_data };
else
unprocessed_nodes.Add(layer, node_data);
unprocessed_nodes[layer].Add(node_data);
}
}

static void process_node(ILayer layer,
NodeConfig node_data,
List<NodeConfig> nodes_data,
Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{

var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
// Append one tensor to input_tensors for every entry in nodes_data.
for (int i = 0; i < nodes_data.Count; i++)
{
var node_data = nodes_data[i];
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

// Look up the already-created layer by name, then the node selected by NodeIndex.
var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
// NOTE(review): Outputs is indexed with inbound_node_index, while
// inbound_tensor_index is read above but never used in this loop.
// Confirm whether inbound_tensor_index was intended here.
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
}

var output_tensors = layer.Apply(input_tensors);



+ 1
- 0
src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs View File

@@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers
shape_set.Add(shape);
}*/
_buildInputShape = input_shape;
built = true;
}

protected override Tensors _merge_function(Tensors inputs)


+ 12
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);

// Read "inbound_nodes" either as a flat list of node entries or as a list
// wrapped in one additional array level, depending on the JSON shape.
List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array
// A non-empty token at [0][0] is taken as the sign of the extra nesting level.
if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0)
{
// Only the first inner list is kept; the ?? fallback substitutes an empty list if it is null.
// NOTE(review): any further inner lists are not read here. Confirm that is acceptable.
nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>();
}
else
{
// Flat (or empty) form: deserialize the array directly.
nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>();
}

config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
InboundNodes = nodeConfig,
});
}
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();


Loading…
Cancel
Save