using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
    /// <summary>
    /// Rebuilds a <see cref="Functional"/> model from its serialized <see cref="ModelConfig"/>.
    /// </summary>
    /// <param name="config">Serialized model description (layers, input/output endpoints, name).</param>
    /// <returns>A new model whose inputs and outputs are the reconstructed endpoint tensors.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
    public static Functional from_config(ModelConfig config)
    {
        if (config == null)
            throw new ArgumentNullException(nameof(config));

        var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
        var model = new Functional(input_tensors, output_tensors, name: config.Name);
        // NOTE(review): every created layer is handed to connect_ancillary_layers; presumably it
        // attaches layers that are not on the input->output path — confirm against its definition.
        model.connect_ancillary_layers(created_layers);
        return model;
    }

    /// <summary>
    /// Reconstructs the layer graph described by <paramref name="config"/>.
    /// </summary>
    /// <param name="config">Serialized model description.</param>
    /// <returns>
    /// The model input tensors, the model output tensors, and every layer created
    /// during reconstruction keyed by layer name.
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// A layer has an unsupported class name, or more than one inbound node entry.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// The inputs of some layer never become available (e.g. a dependency cycle or a
    /// reference to a node that is never created).
    /// </exception>
    static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
    {
        // Layer instances created during the graph reconstruction process.
        var created_layers = new Dictionary<string, ILayer>();
        var node_index_map = new Dictionary<(string, int), int>();
        var node_count_by_layer = new Dictionary<ILayer, int>();
        var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();

        // First, we create all layers and enqueue nodes to be processed.
        foreach (var layer_data in config.Layers)
            process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);

        // Then we process nodes in the order the layers appear in config.Layers.
        // Nodes that cannot yet be processed (because their inbound node does not
        // exist yet) stay enqueued, and the pass is repeated until all nodes are processed.
        while (unprocessed_nodes.Count > 0)
        {
            var made_progress = false;
            foreach (var layer_data in config.Layers)
            {
                var layer = created_layers[layer_data.Name];
                if (!unprocessed_nodes.TryGetValue(layer, out var node_data))
                    continue;

                if (process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map))
                {
                    unprocessed_nodes.Remove(layer);
                    made_progress = true;
                }
            }

            // A full pass without progress means the remaining nodes wait on nodes that will
            // never be created; another pass would spin forever.
            if (!made_progress)
            {
                var pending = string.Join(", ", unprocessed_nodes.Keys.Select(l => l.Name));
                throw new InvalidOperationException(
                    $"Unable to reconstruct the model graph: the inbound nodes of layer(s) [{pending}] never became available.");
            }
        }

        var input_tensors = collect_endpoint_tensors(config.InputLayers, created_layers);
        var output_tensors = collect_endpoint_tensors(config.OutputLayers, created_layers);
        return (input_tensors, output_tensors, created_layers);
    }

    /// <summary>
    /// Resolves endpoint references (model inputs or outputs) to tensors: for each entry,
    /// takes output <c>TensorIndex</c> of inbound node <c>NodeIndex</c> of layer <c>Name</c>.
    /// </summary>
    /// <param name="endpoints">Endpoint references from the model config.</param>
    /// <param name="created_layers">Layers created during reconstruction, keyed by name.</param>
    /// <returns>The referenced tensors, in endpoint order.</returns>
    static List<Tensor> collect_endpoint_tensors(IEnumerable<NodeConfig> endpoints,
        Dictionary<string, ILayer> created_layers)
    {
        var tensors = new List<Tensor>();
        foreach (var endpoint in endpoints)
        {
            var layer = created_layers[endpoint.Name];
            var layer_output_tensors = layer.InboundNodes[endpoint.NodeIndex].Outputs;
            tensors.Add(layer_output_tensors[endpoint.TensorIndex]);
        }
        return tensors;
    }

    /// <summary>
    /// Creates (or reuses) the layer described by <paramref name="layer_data"/> and enqueues
    /// its inbound node for later processing.
    /// </summary>
    /// <param name="created_layers">Layers created so far, keyed by name; updated in place.</param>
    /// <param name="layer_data">Serialized description of one layer.</param>
    /// <param name="unprocessed_nodes">Pending inbound node per layer; updated in place.</param>
    /// <param name="node_count_by_layer">Per-layer count of nodes created; initialized here.</param>
    /// <exception cref="NotImplementedException">
    /// The layer class is not supported, or the layer has more than one inbound node entry.
    /// </exception>
    static void process_layer(Dictionary<string, ILayer> created_layers,
        LayerConfig layer_data,
        Dictionary<ILayer, NodeConfig> unprocessed_nodes,
        Dictionary<ILayer, int> node_count_by_layer)
    {
        var layer_name = layer_data.Name;
        if (!created_layers.TryGetValue(layer_name, out var layer))
        {
            layer = layer_data.ClassName switch
            {
                "InputLayer" => InputLayer.from_config(layer_data.Config),
                "Dense" => Dense.from_config(layer_data.Config),
                _ => throw new NotImplementedException(
                    $"Layer class '{layer_data.ClassName}' is not supported by Functional.from_config.")
            };
            created_layers[layer_name] = layer;
        }

        node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;

        foreach (var node_data in layer_data.InboundNodes)
        {
            // Only one pending entry per layer can be stored; a second entry (multi-input
            // or shared layer) used to hit Dictionary.Add with an existing key.
            if (unprocessed_nodes.ContainsKey(layer))
                throw new NotImplementedException(
                    $"Layer '{layer_name}' has more than one inbound node entry; reconstructing multi-input or shared layers is not supported.");
            unprocessed_nodes[layer] = node_data;
        }
    }

    /// <summary>
    /// Calls <paramref name="layer"/> on the tensor referenced by <paramref name="node_data"/>,
    /// if that tensor already exists.
    /// </summary>
    /// <param name="layer">Layer to apply.</param>
    /// <param name="node_data">Reference to the inbound layer, node index and tensor index.</param>
    /// <param name="created_layers">Layers created so far, keyed by name.</param>
    /// <param name="node_count_by_layer">Per-layer node counter; incremented when the layer is applied.</param>
    /// <param name="node_index_map">Maps (layer name, config node index) to the node index actually created.</param>
    /// <returns>
    /// <c>true</c> if the layer was applied; <c>false</c> if the inbound node has not been
    /// created yet and the node must be retried in a later pass.
    /// </returns>
    static bool process_node(ILayer layer,
        NodeConfig node_data,
        Dictionary<string, ILayer> created_layers,
        Dictionary<ILayer, int> node_count_by_layer,
        Dictionary<(string, int), int> node_index_map)
    {
        var inbound_layer = created_layers[node_data.Name];
        var inbound_node_index = node_data.NodeIndex;

        // The inbound layer has not been called often enough yet; defer this node.
        if (inbound_node_index >= inbound_layer.InboundNodes.Count)
            return false;

        var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
        var input_tensors = new List<Tensor>
        {
            // Select the output by tensor index (previously the node index was used here).
            inbound_node.Outputs[node_data.TensorIndex]
        };
        var output_tensors = layer.Apply(input_tensors);

        // Update node index map.
        var output_index = output_tensors[0].KerasHistory.NodeIndex;
        node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index;
        node_count_by_layer[layer] += 1;
        return true;
    }

    /// <summary>
    /// Returns true if the layer's first node should not be saved or loaded: a nested
    /// <see cref="Functional"/> model that starts with an <see cref="InputLayer"/>, or one
    /// with no layers at all.
    /// </summary>
    static bool _should_skip_first_node(ILayer layer)
    {
        return layer is Functional
            && (layer.Layers.Count == 0 || layer.Layers[0] is InputLayer);
    }
}
}