using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
    public partial class Functional
    {
        /// <summary>
        /// Builds a <see cref="Functional"/> model from a serialized <see cref="ModelConfig"/>.
        /// </summary>
        /// <param name="config">Model configuration describing the layers, the model inputs and the model outputs.</param>
        /// <returns>
        /// A new model wired from the reconstructed input and output tensors. All layers created
        /// during reconstruction are passed to <c>connect_ancillary_layers</c>.
        /// </returns>
        public static Functional from_config(ModelConfig config)
        {
            var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
            var model = new Functional(input_tensors, output_tensors, name: config.Name);
            model.connect_ancillary_layers(created_layers);
            return model;
        }

        /// <summary>
        /// Reconstructs the layer graph described by <paramref name="config"/>.
        /// </summary>
        /// <param name="config">Serialized model configuration.</param>
        /// <returns>
        /// The model input tensors, the model output tensors, and the created layers keyed by layer name.
        /// </returns>
        /// <exception cref="NotImplementedException">
        /// A layer class other than <c>InputLayer</c> or <c>Dense</c> is encountered, or a layer lists
        /// more than one inbound node entry.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// The inbound node of some layer is never created, so the graph cannot be completed.
        /// </exception>
        static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
        {
            // Layer instances created during the graph reconstruction process.
            var created_layers = new Dictionary<string, ILayer>();
            // NOTE(review): written by process_node but not yet read here. Keras uses the equivalent
            // map to translate config node indices; this port uses config node indices directly.
            var node_index_map = new Dictionary<(string, int), int>();
            var node_count_by_layer = new Dictionary<ILayer, int>();
            var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();

            // First, we create all layers and enqueue nodes to be processed.
            foreach (var layer_data in config.Layers)
                process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);

            // Then we process nodes in config order. A node whose inbound node does not exist yet
            // stays queued and is retried on the next pass. Passes repeat until every node has been
            // processed, or until a full pass makes no progress (the graph cannot be completed).
            while (unprocessed_nodes.Count > 0)
            {
                var progressed = false;
                foreach (var layer_data in config.Layers)
                {
                    var layer = created_layers[layer_data.Name];
                    if (!unprocessed_nodes.TryGetValue(layer, out var node_data))
                        continue;

                    if (process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map))
                    {
                        unprocessed_nodes.Remove(layer);
                        progressed = true;
                    }
                }

                if (!progressed)
                {
                    var stuck = string.Join(", ", unprocessed_nodes.Keys.Select(l => l.Name));
                    throw new ArgumentException(
                        $"Cannot reconstruct model '{config.Name}': the inbound nodes of layer(s) [{stuck}] are never created.",
                        nameof(config));
                }
            }

            // Model inputs and outputs are addressed as (layer name, node index, tensor index).
            var input_tensors = config.InputLayers
                .Select(t => created_layers[t.Name].InboundNodes[t.NodeIndex].Outputs[t.TensorIndex])
                .ToList();
            var output_tensors = config.OutputLayers
                .Select(t => created_layers[t.Name].InboundNodes[t.NodeIndex].Outputs[t.TensorIndex])
                .ToList();

            return (input_tensors, output_tensors, created_layers);
        }

        /// <summary>
        /// Creates (or reuses, when the name is already known) the layer described by
        /// <paramref name="layer_data"/> and queues its inbound node for processing.
        /// </summary>
        /// <param name="created_layers">Layers created so far, keyed by name. Updated in place.</param>
        /// <param name="layer_data">Serialized layer entry.</param>
        /// <param name="unprocessed_nodes">Queue of pending inbound nodes (one per layer). Updated in place.</param>
        /// <param name="node_count_by_layer">Per-layer node counter. Reset for this layer.</param>
        /// <exception cref="NotImplementedException">
        /// The layer class is not supported, or the layer lists more than one inbound node entry.
        /// </exception>
        static void process_layer(Dictionary<string, ILayer> created_layers,
            LayerConfig layer_data,
            Dictionary<ILayer, NodeConfig> unprocessed_nodes,
            Dictionary<ILayer, int> node_count_by_layer)
        {
            var layer_name = layer_data.Name;
            if (!created_layers.TryGetValue(layer_name, out var layer))
            {
                layer = layer_data.ClassName switch
                {
                    "InputLayer" => InputLayer.from_config(layer_data.Config),
                    "Dense" => Dense.from_config(layer_data.Config),
                    _ => throw new NotImplementedException(
                        $"Deserializing layer '{layer_name}' of class '{layer_data.ClassName}' is not supported yet.")
                };

                created_layers[layer_name] = layer;
            }

            node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;

            foreach (var node_data in layer_data.InboundNodes)
            {
                // The queue holds a single NodeConfig per layer. Fail clearly on a second entry
                // instead of letting Dictionary.Add throw an opaque duplicate-key error.
                if (unprocessed_nodes.ContainsKey(layer))
                    throw new NotImplementedException(
                        $"Layer '{layer_name}' has more than one inbound node entry; only a single inbound node per layer is supported.");
                unprocessed_nodes[layer] = node_data;
            }
        }

        /// <summary>
        /// Calls <paramref name="layer"/> on the tensor referenced by <paramref name="node_data"/>,
        /// provided the referenced node of the inbound layer already exists.
        /// </summary>
        /// <param name="layer">Layer to apply.</param>
        /// <param name="node_data">Reference to the input: inbound layer name, node index and tensor index.</param>
        /// <param name="created_layers">Layers created so far, keyed by name.</param>
        /// <param name="node_count_by_layer">Per-layer node counter. Incremented for <paramref name="layer"/> on success.</param>
        /// <param name="node_index_map">Maps (layer name, node count) to the node index of the produced output.</param>
        /// <returns>
        /// <c>true</c> if the layer was applied. <c>false</c> if the inbound layer does not have the
        /// referenced node yet; in that case nothing is modified and the caller may retry later.
        /// </returns>
        static bool process_node(ILayer layer,
            NodeConfig node_data,
            Dictionary<string, ILayer> created_layers,
            Dictionary<ILayer, int> node_count_by_layer,
            Dictionary<(string, int), int> node_index_map)
        {
            var inbound_layer = created_layers[node_data.Name];
            var inbound_node_index = node_data.NodeIndex;

            // The inbound layer has not been called on its own inputs yet.
            if (inbound_node_index >= inbound_layer.InboundNodes.Count)
                return false;

            var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
            // NodeIndex selects the node of the inbound layer; TensorIndex selects the tensor within its outputs.
            var input_tensors = new List<Tensor> { inbound_node.Outputs[node_data.TensorIndex] };

            var output_tensors = layer.Apply(input_tensors);

            // Update node index map.
            var output_index = output_tensors[0].KerasHistory.NodeIndex;
            node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index;
            node_count_by_layer[layer] += 1;
            return true;
        }

        /// <summary>
        /// Returns <c>true</c> when node counting for <paramref name="layer"/> should start at 1 instead of 0.
        /// This applies to a <see cref="Functional"/> whose first layer is an <see cref="InputLayer"/>.
        /// </summary>
        static bool _should_skip_first_node(ILayer layer)
        {
            if (!(layer is Functional))
                return false;

            // A model with no layers has nothing to index. Like Keras, treat it as skipping
            // instead of throwing an out-of-range error.
            return layer.Layers.Count == 0 || layer.Layers[0] is InputLayer;
        }
    }
}