diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs index 0140b3dd..9bcf1908 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs @@ -1,13 +1,15 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { // TODO: complete the implementation - public class MergeArgs : LayerArgs + public class MergeArgs : AutoSerializeLayerArgs { public Tensors Inputs { get; set; } + [JsonProperty("axis")] public int Axis { get; set; } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs index 7b826af8..375fc910 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine created_layers = created_layers ?? new Dictionary(); var node_index_map = new Dictionary<(string, int), int>(); var node_count_by_layer = new Dictionary(); - var unprocessed_nodes = new Dictionary(); + var unprocessed_nodes = new Dictionary>(); // First, we create all layers and enqueue nodes to be processed foreach (var layer_data in config.Layers) process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); @@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine static void process_layer(Dictionary created_layers, LayerConfig layer_data, - Dictionary unprocessed_nodes, + Dictionary> unprocessed_nodes, Dictionary node_count_by_layer) { ILayer layer = null; @@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine created_layers[layer_name] = layer; } - node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0; + node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0); var inbound_nodes_data = layer_data.InboundNodes; foreach (var node_data in inbound_nodes_data) { if (!unprocessed_nodes.ContainsKey(layer)) - unprocessed_nodes[layer] = node_data; + unprocessed_nodes[layer] = new List() { node_data }; else - unprocessed_nodes.Add(layer, node_data); + unprocessed_nodes[layer].Add(node_data); } } static void process_node(ILayer layer, - NodeConfig node_data, + List nodes_data, Dictionary created_layers, Dictionary node_count_by_layer, Dictionary<(string, int), int> node_index_map) { + var input_tensors = new List(); - var inbound_layer_name = node_data.Name; - var inbound_node_index = node_data.NodeIndex; - var inbound_tensor_index = node_data.TensorIndex; - var inbound_layer = created_layers[inbound_layer_name]; - var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; - input_tensors.Add(inbound_node.Outputs[inbound_node_index]); + for (int i = 0; i < nodes_data.Count; i++) + { + var node_data = nodes_data[i]; + var inbound_layer_name = node_data.Name; + var inbound_node_index = node_data.NodeIndex; + var inbound_tensor_index = node_data.TensorIndex; + + var inbound_layer = created_layers[inbound_layer_name]; + var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; + input_tensors.Add(inbound_node.Outputs[inbound_node_index]); + } var output_tensors = layer.Apply(input_tensors); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index a2a8286b..fa82426c 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers shape_set.Add(shape); }*/ _buildInputShape = input_shape; + built = true; } protected override Tensors _merge_function(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 5402f499..20937e2e 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils foreach (var token in layersToken) { var args = deserialize_layer_args(token["class_name"].ToObject(), token["config"]); + + List nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array + if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0) + { + nodeConfig = token["inbound_nodes"].ToObject>>().FirstOrDefault() ?? new List(); + } + else + { + nodeConfig = token["inbound_nodes"].ToObject>(); + } + config.Layers.Add(new LayerConfig() { Config = args, Name = token["name"].ToObject(), ClassName = token["class_name"].ToObject(), - InboundNodes = token["inbound_nodes"].ToObject>() + InboundNodes = nodeConfig, }); } config.InputLayers = json["input_layers"].ToObject>();