using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
    /// <summary>
    /// A `Functional` model is a `Model` defined as a directed graph of layers.
    /// </summary>
    public partial class Functional : Model
    {
        List<ILayer> _output_layers;
        List<ILayer> _input_layers;
        List<KerasHistory> _input_coordinates;
        List<KerasHistory> _output_coordinates;
        public string[] NetworkNodes { get; set; }

        Dictionary<long, int> tensor_usage_count;

        /// <summary>
        /// Dictionary of layer dependencies to be included in the checkpoint.
        /// </summary>
        public IDictionary<string, ILayer> LayerCheckpointDependencies
        {
            get
            {
                int weight_layer_index = 0;
                Dictionary<string, ILayer> dependencies = new();
                for (int i = 0; i < Layers.Count; i++)
                {
                    var layer = Layers[i];
                    var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList();
                    // Layers that own weights get an extra, stable key so that
                    // checkpoints remain valid if weight-less layers are inserted.
                    if (weights.Count > 0)
                    {
                        dependencies[$"layer_with_weights-{weight_layer_index}"] = layer;
                        weight_layer_index++;
                    }
                    dependencies[$"layer-{i}"] = layer;
                }
                return dependencies;
            }
        }

        public Functional(Tensors inputs, Tensors outputs, string name = null)
            : base(new ModelArgs
            {
                Name = name,
                Inputs = inputs,
                Outputs = outputs
            })
        {
            Initialize(inputs, outputs, name);
        }

        internal void Initialize(Tensors inputs, Tensors outputs, string name = null)
        {
            _input_layers = new List<ILayer>();
            _output_layers = new List<ILayer>();
            _input_coordinates = new List<KerasHistory>();
            _output_coordinates = new List<KerasHistory>();
            tensor_usage_count = new Dictionary<long, int>();
            if (this is Sequential)
                return;
            _init_graph_network(inputs, outputs);
        }

        protected void _init_graph_network(Tensors inputs, Tensors outputs)
        {
            _is_graph_network = true;
            this.inputs = inputs;
            this.outputs = outputs;
            built = true;
            if (inputs.Length > 0)
            {
                _buildInputShape = inputs.shape;
            }
            else
            {
                _buildInputShape = new TensorShapeConfig();
            }

            if (outputs.Any(x => x.KerasHistory == null))
                base_layer_utils.create_keras_history(outputs);

            // Build self._output_layers:
            foreach (var x in outputs)
            {
                var (layer, node_index, tensor_index) = x.KerasHistory;
                _output_layers.append(layer);
                _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
            }

            // Build self._input_layers:
            foreach (var x in inputs)
            {
                var (layer, node_index, tensor_index) = x.KerasHistory;
                _input_layers.append(layer);
                _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
            }

            // Keep track of the network's nodes and layers.
            (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);

            // Build self.input_names and self.output_names.
            _set_output_names();

            ComputeTensorUsageCount();
        }
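        // Illustrative sketch (added commentary, not part of the original source;
        // the keras API calls below are assumptions that may differ by TF.NET version):
        // for a simple chain such as
        //
        //     var inputs = keras.Input(shape: 784);             // InputLayer "input_1"
        //     var x = keras.layers.Dense(64).Apply(inputs);     // "dense"
        //     var outputs = keras.layers.Dense(10).Apply(x);    // "dense_1"
        //     var model = keras.Model(inputs, outputs);
        //
        // _init_graph_network stores _input_layers = [input_1] and
        // _output_layers = [dense_1], each with KerasHistory coordinates
        // (layer, node_index: 0, tensor_index: 0), since no layer here is
        // shared (multiple nodes) or multi-output (multiple tensors per node).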
        /// <summary>
        /// Assigns unique names to the Network's outputs.
        /// </summary>
        void _set_output_names()
        {
            var uniquified = new List<string>();
            var output_names = new List<string>();
            var prefix_count = new Dictionary<string, int>();
            foreach (var layer in _output_layers)
            {
                var proposal = layer.Name;
                while (output_names.Contains(proposal))
                {
                    var existing_count = prefix_count.Get(layer.Name, 1);
                    proposal = $"{layer.Name}_{existing_count}";
                    prefix_count[layer.Name] = existing_count + 1;
                }
                output_names.add(proposal);
                uniquified.append(proposal);
            }
            this.output_names = uniquified.ToArray();
        }

        void ComputeTensorUsageCount()
        {
            var available_tensors = inputs.Select(x => x.Id).ToList();
            var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
            foreach (var depth in depth_keys)
            {
                foreach (var node in NodesByDepth[depth])
                {
                    var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray();
                    if (input_tensors.issubset(available_tensors))
                    {
                        foreach (var tensor in node.KerasInputs)
                        {
                            if (!tensor_usage_count.ContainsKey(tensor.Id))
                                tensor_usage_count[tensor.Id] = 0;
                            tensor_usage_count[tensor.Id] += 1;
                        }

                        foreach (var output_tensor in node.Outputs)
                            available_tensors.Add(output_tensor.Id);
                    }
                }
            }

            foreach (var tensor in outputs)
            {
                if (!tensor_usage_count.ContainsKey(tensor.Id))
                    tensor_usage_count[tensor.Id] = 0;
                tensor_usage_count[tensor.Id] += 1;
            }
        }

        /// <summary>
        /// Validates a network's topology and gathers its layers and nodes.
        /// </summary>
        /// <param name="inputs"></param>
        /// <param name="outputs"></param>
        /// <returns></returns>
        (string[], Dictionary<int, List<INode>>, List<ILayer>, Dictionary<int, List<ILayer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
        {
            var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);
            var network_nodes = nodes_in_decreasing_depth
                .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
                .ToArray();

            var nodes_depths = new Dictionary<INode, int>();
            var layers_depths = new Dictionary<ILayer, int>();

            nodes_in_decreasing_depth.Reverse();
            foreach (var node in nodes_in_decreasing_depth)
            {
                // If the depth is not set, the node has no outbound nodes (depth 0).
                int depth = nodes_depths.SetDefault(node, 0);
                // Update the depth of the corresponding layer.
                int previous_depth = layers_depths.Get(node.Layer, 0);
                // If we've seen this layer before at a higher depth,
                // we should use that depth instead of the node depth.
                // This is necessary for shared layers that have inputs at different
                // depth levels in the graph.
                depth = Math.Max(depth, previous_depth);
                layers_depths[node.Layer] = depth;
                nodes_depths[node] = depth;

                // Update the depth of inbound nodes.
                // The "depth" of a node is the max of the depths
                // of all the nodes it is connected to + 1.
                foreach (var node_dep in node.ParentNodes)
                {
                    previous_depth = nodes_depths.Get(node_dep, 0);
                    nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
                }
            }
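            // Worked example (added commentary, not part of the original source):
            // for the chain Input -> Dense -> Dense, BuildMap returns nodes in
            // decreasing depth, so after the Reverse() above they are visited
            // starting from the output node. The output Dense node defaults to
            // depth 0; its parent Dense node is then raised to max(0 + 1, 0) = 1,
            // and the Input node to max(1 + 1, 0) = 2. For a layer shared by
            // nodes at depths 1 and 3, layers_depths keeps the max, 3, which is
            // what the comment about shared layers above refers to.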
            // Handle inputs that are not connected to outputs.
            // We do not error out here because the inputs may be used to compute losses
            // and metrics.
            foreach (var input_t in inputs)
            {
                var (input_layer, _, _) = input_t.KerasHistory;
                if (!layers_depths.ContainsKey(input_layer))
                {
                    layers_depths[input_layer] = 0;
                    layer_indices[input_layer] = -1;
                    nodes_depths[input_layer.InboundNodes[0]] = 0;
                    network_nodes.add(MakeNodeKey(input_layer.Name, 0));
                }
            }

            // Build a dict {depth: list of nodes with this depth}.
            var nodes_by_depth = new Dictionary<int, List<INode>>();
            foreach (var (node, depth) in enumerate(nodes_depths))
            {
                if (!nodes_by_depth.ContainsKey(depth))
                    nodes_by_depth[depth] = new List<INode>();
                nodes_by_depth[depth].append(node);
            }

            // Build a dict {depth: list of layers with this depth}.
            var layers_by_depth = new Dictionary<int, List<ILayer>>();
            foreach (var (layer, depth) in enumerate(layers_depths))
            {
                if (!layers_by_depth.ContainsKey(depth))
                    layers_by_depth[depth] = new List<ILayer>();
                layers_by_depth[depth].append(layer);
            }

            // Get sorted list of layer depths.
            var depth_keys = layers_by_depth.Keys.OrderBy(x => x).Reverse();

            // Set self.layers ordered by depth.
            var layers = new List<ILayer>();
            foreach (var depth in depth_keys)
            {
                var layers_for_depth = layers_by_depth[depth];
                // Network.layers needs to have a deterministic order:
                // here we order them by traversal order.
                layers_for_depth = layers_for_depth.OrderBy(x => layer_indices[x]).ToList();
                layers.AddRange(layers_for_depth);
            }

            // Get sorted list of node depths.
            depth_keys = nodes_by_depth.Keys.OrderBy(x => x).Reverse();

            return (network_nodes, nodes_by_depth, layers, layers_by_depth);
        }

        string MakeNodeKey(string layer_name, int node_index)
            => $"{layer_name}_ib-{node_index}";

        /// <summary>
        /// This method topologically sorts nodes in order from inputs to outputs.
        /// </summary>
        /// <param name="outputs"></param>
        /// <returns></returns>
        (List<Node>, Dictionary<ILayer, int>) BuildMap(Tensors outputs)
        {
            var finished_nodes = new List<Node>();
            var nodes_in_progress = new List<Node>();
            var nodes_in_decreasing_depth = new List<Node>();
            var layer_indices = new Dictionary<ILayer, int>();
            foreach (var output in outputs)
                BuildMapHelper(output,
                    finished_nodes,
                    nodes_in_progress,
                    nodes_in_decreasing_depth,
                    layer_indices);

            return (nodes_in_decreasing_depth, layer_indices);
        }

        void BuildMapHelper(Tensor tensor,
            List<Node> finished_nodes,
            List<Node> nodes_in_progress,
            List<Node> nodes_in_decreasing_depth,
            Dictionary<ILayer, int> layer_indices)
        {
            var (layer, node_index, _) = tensor.KerasHistory;
            var node = layer.InboundNodes[node_index] as Node;

            // Don't repeat work for shared subgraphs.
            if (finished_nodes.Contains(node))
                return;

            // Prevent cycles.
            if (nodes_in_progress.Contains(node))
                throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");

            // Store the traversal order for layer sorting.
            if (!layer_indices.ContainsKey(layer))
                layer_indices[layer] = layer_indices.Count;

            // Propagate to all previous tensors connected to this node.
            nodes_in_progress.Add(node);
            if (!node.is_input)
            {
                foreach (var k_tensor in node.KerasInputs)
                {
                    BuildMapHelper(k_tensor,
                        finished_nodes,
                        nodes_in_progress,
                        nodes_in_decreasing_depth,
                        layer_indices);
                }
            }

            finished_nodes.Add(node);
            nodes_in_progress.Remove(node);
            nodes_in_decreasing_depth.append(node);
        }

        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            var tensor_dict = new Dictionary<long, Queue<Tensor>>();
            // Map input values to the model's symbolic input tensors.
            foreach (var (x, y) in zip(this.inputs, inputs))
            {
                tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(_ => y));
            }

            var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
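            // Execution-plan sketch (added commentary, not part of the original
            // source): tensor_dict maps each symbolic tensor's Id to a queue that
            // holds one reference per planned consumer, sized by the counts from
            // ComputeTensorUsageCount. A tensor feeding two downstream layers is
            // enqueued twice and each consumer dequeues its own entry, so the
            // queue drains exactly when the last consumer has run. Iterating
            // depths from highest (inputs) to 0 (outputs) guarantees a node's
            // inputs are materialized before the node itself executes.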
            foreach (var depth in depth_keys)
            {
                var nodes = NodesByDepth[depth];
                foreach (Node node in nodes)
                {
                    // Input tensors already exist.
                    if (node.is_input)
                        continue;

                    var layer_inputs = node.MapArguments(tensor_dict);

                    tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
                    var outputs = node.Layer.Apply(layer_inputs, training: training ?? false);
                    foreach (var output in outputs.Where(x => x != null))
                        tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.shape}");

                    // Update tensor_dict so the outputs are available to later nodes.
                    foreach (var (x_id, y) in zip(node.Outputs.Select(x => x.Id), outputs))
                        tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(_ => y));
                }
            }

            var output_tensors = new Tensors();
            foreach (var x in outputs)
                output_tensors.Add(tensor_dict[x.Id].Dequeue());

            return output_tensors;
        }

        public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
        {
            return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable())
                .Concat(base._trackable_children(save_type, cache))
                .ToDictionary(x => x.Key, x => x.Value);
        }

        protected override void _init_set_name(string name, bool zero_based = true)
        {
            if (string.IsNullOrEmpty(name))
            {
                string class_name = GetType().Name;
                if (this.GetType() == typeof(Functional))
                {
                    class_name = "Model";
                }
                this.name = base_layer_utils.unique_layer_name(
                    generic_utils.to_snake_case(class_name),
                    zero_based: zero_based);
            }
            else
            {
                this.name = name;
            }
        }
    }
}
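// Usage sketch (illustrative, not part of the original source; the exact keras
// API signatures are assumptions and may vary across Tensorflow.NET versions):
//
//     using static Tensorflow.KerasApi;
//
//     var inputs = keras.Input(shape: 784);
//     var x = keras.layers.Dense(64, activation: "relu").Apply(inputs);
//     var outputs = keras.layers.Dense(10, activation: "softmax").Apply(x);
//     var model = keras.Model(inputs, outputs, name: "mnist_classifier");
//     model.summary();
//
// Applying each layer to symbolic tensors records KerasHistory metadata on the
// resulting tensors; the Functional constructor replays that history via
// _init_graph_network to rebuild the full layer graph from inputs and outputs.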