You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Functional.FromConfig.cs 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Keras.Layers;
  6. using Tensorflow.Keras.Saving;
  7. using Tensorflow.Keras.Utils;
  8. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Functional
      {
          /// <summary>
          /// Builds a <see cref="Functional"/> model from a <see cref="ModelConfig"/>.
          /// </summary>
          /// <param name="config">
          /// Config listing the model's layers (with their inbound nodes) and the
          /// layer/node/tensor references that become the model's inputs and outputs.
          /// </param>
          /// <returns>A new model named <c>config.Name</c>.</returns>
          public static Functional from_config(ModelConfig config)
          {
              var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
              var model = new Functional(input_tensors, output_tensors, name: config.Name);
              // NOTE(review): presumably registers created layers that are not on the
              // input-to-output path; confirm against connect_ancillary_layers.
              model.connect_ancillary_layers(created_layers);
              return model;
          }

          /// <summary>
          /// Reconstructs the layer graph described by <paramref name="config"/>.
          /// </summary>
          /// <param name="config">Model config to rebuild.</param>
          /// <returns>
          /// The model's input tensors, its output tensors, and every layer created during
          /// reconstruction, keyed by layer name.
          /// </returns>
          /// <exception cref="InvalidOperationException">
          /// A full pass over the pending nodes made no progress (some node references an
          /// inbound node that is never created), or an input/output reference points at a
          /// node that was never reconstructed.
          /// </exception>
          static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
          {
              // Layer instances created during the graph reconstruction process.
              var created_layers = new Dictionary<string, ILayer>();
              // (layer name, node index as recorded in the config) -> index of that node in the
              // reconstructed layer's InboundNodes.
              var node_index_map = new Dictionary<(string, int), int>();
              var node_count_by_layer = new Dictionary<ILayer, int>();
              // Per layer, the nodes still waiting to be re-created, in config order.
              // Lists stored here are never empty.
              var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();

              // First, create all layers and enqueue their nodes.
              foreach (var layer_data in config.Layers)
                  process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);

              // Then process nodes in config layer order. A node whose inbound node does not
              // exist yet stays queued (together with the remaining nodes of the same layer,
              // so their order is preserved) and is retried on the next pass.
              while (unprocessed_nodes.Count > 0)
              {
                  var progressed = false;
                  foreach (var layer_data in config.Layers)
                  {
                      var layer = created_layers[layer_data.Name];
                      if (!unprocessed_nodes.TryGetValue(layer, out var layer_nodes))
                          continue;

                      while (layer_nodes.Count > 0
                          && process_node(layer, layer_nodes[0], created_layers, node_count_by_layer, node_index_map))
                      {
                          layer_nodes.RemoveAt(0);
                          progressed = true;
                      }

                      if (layer_nodes.Count == 0)
                          unprocessed_nodes.Remove(layer);
                  }

                  // Without this guard an unresolvable reference would loop forever.
                  if (!progressed)
                      throw new InvalidOperationException(
                          "Could not reconstruct the model graph: some nodes reference inbound nodes that are never created.");
              }

              // Looks up the tensor produced by the given (config-indexed) node of a layer.
              Tensor resolve_tensor(string layer_name, int config_node_index, int tensor_index)
              {
                  var layer = created_layers[layer_name];
                  var node_index = get_node_index(layer, config_node_index, node_index_map);
                  if (node_index == null)
                      throw new InvalidOperationException(
                          $"Layer '{layer_name}' has no reconstructed node for config node index {config_node_index}.");
                  return layer.InboundNodes[node_index.Value].Outputs[tensor_index];
              }

              var input_tensors = new List<Tensor>();
              foreach (var layer_data in config.InputLayers)
                  input_tensors.Add(resolve_tensor(layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex));

              var output_tensors = new List<Tensor>();
              foreach (var layer_data in config.OutputLayers)
                  output_tensors.Add(resolve_tensor(layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex));

              return (input_tensors, output_tensors, created_layers);
          }

          /// <summary>
          /// Creates (or reuses, if already created) the layer described by
          /// <paramref name="layer_data"/>, resets its node counter, and queues each of its
          /// inbound nodes for <see cref="process_node"/>.
          /// </summary>
          /// <param name="created_layers">Layers created so far, keyed by name; updated in place.</param>
          /// <param name="layer_data">Config of the layer to create.</param>
          /// <param name="unprocessed_nodes">Per-layer queue of pending nodes; appended to in config order.</param>
          /// <param name="node_count_by_layer">Per-layer count of processed nodes; initialized here.</param>
          /// <exception cref="NotImplementedException">The layer's class name is not supported.</exception>
          static void process_layer(Dictionary<string, ILayer> created_layers,
              LayerConfig layer_data,
              Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
              Dictionary<ILayer, int> node_count_by_layer)
          {
              var layer_name = layer_data.Name;
              if (!created_layers.TryGetValue(layer_name, out var layer))
              {
                  layer = layer_data.ClassName switch
                  {
                      "InputLayer" => InputLayer.from_config(layer_data.Config),
                      "Dense" => Dense.from_config(layer_data.Config),
                      _ => throw new NotImplementedException($"Layer class '{layer_data.ClassName}' is not supported by from_config.")
                  };
                  created_layers[layer_name] = layer;
              }

              // Node numbering starts at 1 for layers whose first node is skipped.
              node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;

              if (layer_data.InboundNodes == null)
                  return;

              // Each inbound node is queued separately; Dictionary.Add on an existing key would
              // throw, so nodes are appended to the layer's list instead.
              foreach (var node_data in layer_data.InboundNodes)
              {
                  if (!unprocessed_nodes.TryGetValue(layer, out var queued))
                      unprocessed_nodes[layer] = queued = new List<NodeConfig>();
                  queued.Add(node_data);
              }
          }

          /// <summary>
          /// Re-creates one node of <paramref name="layer"/> by applying the layer to the
          /// tensor referenced by <paramref name="node_data"/>.
          /// </summary>
          /// <param name="layer">Layer to apply.</param>
          /// <param name="node_data">Reference (layer name, config node index, tensor index) to the input tensor.</param>
          /// <param name="created_layers">All created layers, keyed by name.</param>
          /// <param name="node_count_by_layer">Per-layer processed-node counter; incremented on success.</param>
          /// <param name="node_index_map">Config-to-actual node index map; extended on success.</param>
          /// <returns>
          /// <c>true</c> if the node was created; <c>false</c> if the referenced inbound node has
          /// not been re-created yet, in which case nothing is modified and the caller retries later.
          /// </returns>
          static bool process_node(ILayer layer,
              NodeConfig node_data,
              Dictionary<string, ILayer> created_layers,
              Dictionary<ILayer, int> node_count_by_layer,
              Dictionary<(string, int), int> node_index_map)
          {
              var inbound_layer = created_layers[node_data.Name];
              var inbound_node_index = get_node_index(inbound_layer, node_data.NodeIndex, node_index_map);
              if (inbound_node_index == null)
                  return false;

              var inbound_node = inbound_layer.InboundNodes[inbound_node_index.Value];
              // The node index selects the node; the tensor index selects which of its outputs.
              var input_tensors = new List<Tensor> { inbound_node.Outputs[node_data.TensorIndex] };
              var output_tensors = layer.Apply(input_tensors);

              // Update node index map: the config index of this node (its ordinal among this
              // layer's processed nodes) -> KerasHistory.NodeIndex of the produced tensor, which is
              // assumed to be the new node's position in layer.InboundNodes.
              var output_index = output_tensors[0].KerasHistory.NodeIndex;
              node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index;
              node_count_by_layer[layer] += 1;
              return true;
          }

          /// <summary>
          /// Translates a node index recorded in the config into an index into
          /// <paramref name="layer"/>'s reconstructed InboundNodes.
          /// </summary>
          /// <returns>The resolved index, or <c>null</c> if that node has not been re-created yet.</returns>
          static int? get_node_index(ILayer layer, int config_node_index, Dictionary<(string, int), int> node_index_map)
          {
              // Input layers are never re-applied during reconstruction; as in Keras'
              // get_node_index, their node is always index 0.
              if (layer is InputLayer)
                  return 0;
              return node_index_map.TryGetValue((layer.Name, config_node_index), out var index)
                  ? index
                  : (int?)null;
          }

          /// <summary>
          /// Returns true when <paramref name="layer"/> is a <see cref="Functional"/> model whose
          /// first sub-layer is an <see cref="InputLayer"/>.
          /// </summary>
          /// <remarks>
          /// NOTE(review): mirrors Keras, where such nested models already own a node linking
          /// input to output that is excluded from the config. Assumes Layers is non-empty.
          /// </remarks>
          static bool _should_skip_first_node(ILayer layer)
          {
              return layer is Functional && layer.Layers[0] is InputLayer;
          }
      }
  }