@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
var unprocessed_nodes = new Dictionary<ILayer, List< NodeConfig> >();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine
static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, List< NodeConfig> > unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer)
{
ILayer layer = null;
@@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine
created_layers[layer_name] = layer;
}
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
node_count_by_layer[layer] = layer_data.InboundNodes.Count - ( _should_skip_first_node(layer) ? 1 : 0) ;
var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
unprocessed_nodes[layer] = node_data;
unprocessed_nodes[layer] = new List<NodeConfig>() { n ode_data } ;
else
unprocessed_nodes.Add(layer, node_data);
unprocessed_nodes[layer].Add( node_data);
}
}
static void process_node(ILayer layer,
NodeConfig node_data,
List< NodeConfig> nodes _data,
Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{
var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;
var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
for (int i = 0; i < nodes_data.Count; i++)
{
var node_data = nodes_data[i];
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;
var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
}
var output_tensors = layer.Apply(input_tensors);