You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Functional.GetConfig.cs 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Keras.Layers;
  6. using Tensorflow.Keras.Saving;
  7. using Tensorflow.Keras.Utils;
  8. using static Tensorflow.Binding;
  9. namespace Tensorflow.Keras.Engine
  10. {
  11. public partial class Functional
  12. {
  13. public ModelConfig get_config()
  14. {
  15. return get_network_config();
  16. }
  17. /// <summary>
  18. /// Builds the config, which consists of the node graph and serialized layers.
  19. /// </summary>
  20. ModelConfig get_network_config()
  21. {
  22. var config = new ModelConfig
  23. {
  24. Name = name
  25. };
  26. var node_conversion_map = new Dictionary<string, int>();
  27. foreach (var layer in _layers)
  28. {
  29. var kept_nodes = _should_skip_first_node(layer) ? 1 : 0;
  30. foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
  31. {
  32. var node_key = _make_node_key(layer.Name, original_node_index);
  33. if (NetworkNodes.Contains(node_key))
  34. {
  35. node_conversion_map[node_key] = kept_nodes;
  36. kept_nodes += 1;
  37. }
  38. }
  39. }
  40. var layer_configs = new List<LayerConfig>();
  41. foreach (var layer in _layers)
  42. {
  43. var filtered_inbound_nodes = new List<NodeConfig>();
  44. foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
  45. {
  46. var node_key = _make_node_key(layer.Name, original_node_index);
  47. if (NetworkNodes.Contains(node_key) && !node.is_input)
  48. {
  49. var node_data = node.serialize(_make_node_key, node_conversion_map);
  50. filtered_inbound_nodes.append(node_data);
  51. }
  52. }
  53. var layer_config = generic_utils.serialize_keras_object(layer);
  54. layer_config.Name = layer.Name;
  55. layer_config.InboundNodes = filtered_inbound_nodes;
  56. layer_configs.Add(layer_config);
  57. }
  58. config.Layers = layer_configs;
  59. // Gather info about inputs and outputs.
  60. var model_inputs = new List<NodeConfig>();
  61. foreach (var i in range(_input_layers.Count))
  62. {
  63. var (layer, node_index, tensor_index) = _input_coordinates[i];
  64. var node_key = _make_node_key(layer.Name, node_index);
  65. if (!NetworkNodes.Contains(node_key))
  66. continue;
  67. var new_node_index = node_conversion_map[node_key];
  68. model_inputs.append(new NodeConfig
  69. {
  70. Name = layer.Name,
  71. NodeIndex = new_node_index,
  72. TensorIndex = tensor_index
  73. });
  74. }
  75. config.InputLayers = model_inputs;
  76. var model_outputs = new List<NodeConfig>();
  77. foreach (var i in range(_output_layers.Count))
  78. {
  79. var (layer, node_index, tensor_index) = _output_coordinates[i];
  80. var node_key = _make_node_key(layer.Name, node_index);
  81. if (!NetworkNodes.Contains(node_key))
  82. continue;
  83. var new_node_index = node_conversion_map[node_key];
  84. model_outputs.append(new NodeConfig
  85. {
  86. Name = layer.Name,
  87. NodeIndex = new_node_index,
  88. TensorIndex = tensor_index
  89. });
  90. }
  91. config.OutputLayers = model_outputs;
  92. return config;
  93. }
  94. string _make_node_key(string layer_name, int node_index)
  95. => $"{layer_name}_ib-{node_index}";
  96. }
  97. }