You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Functional.cs 16 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Common.Types;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Saving.SavedModel;
  7. using Tensorflow.Keras.Utils;
  8. using Tensorflow.Train;
  9. using static Tensorflow.Binding;
  10. namespace Tensorflow.Keras.Engine
  11. {
  12. /// <summary>
  13. /// A `Functional` model is a `Model` defined as a directed graph of layers.
  14. /// </summary>
  15. public partial class Functional : Model
  16. {
  /// <summary>Layers producing the model outputs; one entry is added per output tensor.</summary>
  private List<ILayer> _output_layers;

  /// <summary>Layers owning the model inputs; one entry is added per input tensor.</summary>
  private List<ILayer> _input_layers;

  /// <summary>(layer, node index, tensor index) history recorded for each input tensor.</summary>
  private List<KerasHistory> _input_coordinates;

  /// <summary>(layer, node index, tensor index) history recorded for each output tensor.</summary>
  private List<KerasHistory> _output_coordinates;

  /// <summary>Node keys of the network, as produced by <c>MapGraphNetwork</c>.</summary>
  public string[] NetworkNodes { get; set; }

  /// <summary>Number of recorded usages per tensor, keyed by tensor id.</summary>
  private Dictionary<long, int> tensor_usage_count;
  /// <summary>
  /// Dictionary of layer dependencies to be included in the checkpoint.
  /// </summary>
  /// <remarks>
  /// Every layer is registered under <c>layer-{i}</c>, where <c>i</c> is its position in
  /// <c>Layers</c>. A layer that owns at least one trainable or non-trainable weight is
  /// additionally registered under <c>layer_with_weights-{k}</c>, where <c>k</c> counts
  /// only the weighted layers encountered so far. A fresh dictionary is built on each access.
  /// </remarks>
  public IDictionary<string, ILayer> LayerCheckpointDependencies
  {
      get
      {
          var result = new Dictionary<string, ILayer>();
          int weighted_seen = 0;
          for (int position = 0; position < Layers.Count; position++)
          {
              var current = Layers[position];
              var all_weights = current.TrainableWeights.concat(current.NonTrainableWeights).ToList();
              bool has_weights = all_weights.Count > 0;
              if (has_weights)
              {
                  // The weighted alias is inserted before the positional key.
                  result[$"layer_with_weights-{weighted_seen}"] = current;
                  weighted_seen += 1;
              }
              result[$"layer-{position}"] = current;
          }
          return result;
      }
  }
  /// <summary>
  /// Creates a functional model from its input and output tensors.
  /// </summary>
  /// <param name="inputs">Input tensors of the model.</param>
  /// <param name="outputs">Output tensors of the model.</param>
  /// <param name="name">Optional model name, forwarded to the base <c>ModelArgs</c>.</param>
  public Functional(Tensors inputs, Tensors outputs, string name = null)
      : base(new ModelArgs()
      {
          Name = name,
          Inputs = inputs,
          Outputs = outputs
      })
  {
      Initialize(inputs: inputs, outputs: outputs, name: name);
  }
  /// <summary>
  /// Resets the bookkeeping collections and, unless this instance is a
  /// <c>Sequential</c>, builds the graph network from the given tensors.
  /// </summary>
  /// <param name="inputs">Input tensors of the model.</param>
  /// <param name="outputs">Output tensors of the model.</param>
  /// <param name="name">Not used by this method.</param>
  internal void Initialize(Tensors inputs, Tensors outputs, string name = null)
  {
      tensor_usage_count = new Dictionary<long, int>();
      _input_coordinates = new List<KerasHistory>();
      _output_coordinates = new List<KerasHistory>();
      _input_layers = new List<ILayer>();
      _output_layers = new List<ILayer>();

      // Sequential instances skip graph construction here.
      if (!(this is Sequential))
      {
          _init_graph_network(inputs, outputs);
      }
  }
  /// <summary>
  /// Wires this model up as a graph network: stores the input/output tensors,
  /// records the layers and history coordinates behind them, maps the node graph
  /// and derives output names and tensor usage counts.
  /// </summary>
  /// <param name="inputs">Input tensors of the network.</param>
  /// <param name="outputs">Output tensors of the network.</param>
  protected void _init_graph_network(Tensors inputs, Tensors outputs)
  {
      _is_graph_network = true;
      this.inputs = inputs;
      this.outputs = outputs;
      built = true;

      if (inputs.Length <= 0)
      {
          _buildInputShape = new TensorShapeConfig();
      }
      else
      {
          _buildInputShape = inputs.shape;
      }

      // Create Keras history when at least one output does not carry it yet.
      bool history_missing = outputs.Any(t => t.KerasHistory == null);
      if (history_missing)
      {
          base_layer_utils.create_keras_history(outputs);
      }

      // Record the owning layer and a copy of the history coordinates of one tensor.
      void Track(Tensor tensor, List<ILayer> layers, List<KerasHistory> coordinates)
      {
          var (layer, node_index, tensor_index) = tensor.KerasHistory;
          layers.Add(layer);
          coordinates.Add(new KerasHistory(layer, node_index, tensor_index));
      }

      // Build self._output_layers:
      foreach (var output_tensor in outputs)
      {
          Track(output_tensor, _output_layers, _output_coordinates);
      }

      // Build self._input_layers:
      foreach (var input_tensor in inputs)
      {
          Track(input_tensor, _input_layers, _input_coordinates);
      }

      // Keep track of the network's nodes and layers.
      (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);

      // Build self.input_names and self.output_names.
      _set_output_names();
      ComputeTensorUsageCount();
  }
  /// <summary>
  /// Assigns unique names to the Network's outputs.
  /// </summary>
  /// <remarks>
  /// Each output takes the name of its producing layer. When that name is already
  /// taken, a numeric suffix (<c>name_1</c>, <c>name_2</c>, ...) is tried until a free
  /// name is found. The result is stored in <c>output_names</c> in output order.
  /// </remarks>
  void _set_output_names()
  {
      var uniquified = new List<string>(_output_layers.Count);
      // Set for O(1) membership checks; the ordered list above keeps output order.
      var taken = new HashSet<string>();
      // Next suffix to try for each layer name that has collided at least once.
      var prefix_count = new Dictionary<string, int>();

      foreach (var layer in _output_layers)
      {
          var base_name = layer.Name;
          var proposal = base_name;
          while (taken.Contains(proposal))
          {
              if (!prefix_count.TryGetValue(base_name, out var existing_count))
              {
                  existing_count = 1;
              }
              proposal = $"{base_name}_{existing_count}";
              prefix_count[base_name] = existing_count + 1;
          }
          taken.Add(proposal);
          uniquified.Add(proposal);
      }

      this.output_names = uniquified.ToArray();
  }
  /// <summary>
  /// Fills <c>tensor_usage_count</c> with the number of times each tensor is consumed.
  /// </summary>
  /// <remarks>
  /// Depth keys are walked from highest to lowest, skipping the highest one. A node
  /// contributes only when <c>issubset</c> accepts its input ids against the ids made
  /// available so far; its outputs then become available as well. Every model output
  /// finally counts as one additional usage.
  /// </remarks>
  void ComputeTensorUsageCount()
  {
      // Adds one usage for a tensor id, creating the entry on first sight.
      void CountUsage(long tensor_id)
      {
          if (tensor_usage_count.ContainsKey(tensor_id))
          {
              tensor_usage_count[tensor_id] += 1;
          }
          else
          {
              tensor_usage_count[tensor_id] = 1;
          }
      }

      var available_tensors = inputs.Select(t => t.Id).ToList();
      var depth_keys = NodesByDepth.Keys
          .OrderByDescending(d => d)
          .Skip(1)
          .ToArray();

      foreach (var depth in depth_keys)
      {
          foreach (var node in NodesByDepth[depth])
          {
              var input_tensors = node.KerasInputs.Select(t => t.Id).ToArray();
              if (!input_tensors.issubset(available_tensors))
              {
                  continue;
              }

              foreach (var consumed in node.KerasInputs)
              {
                  CountUsage(consumed.Id);
              }
              foreach (var produced in node.Outputs)
              {
                  available_tensors.Add(produced.Id);
              }
          }
      }

      foreach (var model_output in outputs)
      {
          CountUsage(model_output.Id);
      }
  }
  /// <summary>
  /// Gathers the network's nodes and layers and assigns a depth to each of them.
  /// </summary>
  /// <param name="inputs">Input tensors of the network.</param>
  /// <param name="outputs">Output tensors of the network.</param>
  /// <returns>
  /// A tuple of: the node keys of the network, the nodes grouped by depth, the layers
  /// ordered by decreasing depth (traversal order within one depth), and the layers
  /// grouped by depth.
  /// </returns>
  (string[], Dictionary<int, List<INode>>, List<ILayer>, Dictionary<int, List<ILayer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
  {
      var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);

      // Kept as a growable list: keys for inputs that are not connected to the outputs
      // are appended further down, which a fixed-size array cannot accept.
      var network_nodes = nodes_in_decreasing_depth
          .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
          .ToList();

      var nodes_depths = new Dictionary<INode, int>();
      var layers_depths = new Dictionary<ILayer, int>();

      // In-place reversal of the list returned by BuildMap.
      nodes_in_decreasing_depth.Reverse();
      foreach (var node in nodes_in_decreasing_depth)
      {
          // If the depth is not set, the node has no outbound nodes (depth 0).
          int depth = nodes_depths.SetDefault(node, 0);

          // Update the depth of the corresponding layer.
          int previous_depth = layers_depths.Get(node.Layer, 0);

          // If we've seen this layer before at a higher depth,
          // we should use that depth instead of the node depth.
          // This is necessary for shared layers that have inputs at different
          // depth levels in the graph.
          depth = Math.Max(depth, previous_depth);
          layers_depths[node.Layer] = depth;
          nodes_depths[node] = depth;

          // Update the depth of inbound nodes.
          // The "depth" of a node is the max of the depths
          // of all nodes it is connected to + 1.
          foreach (var node_dep in node.ParentNodes)
          {
              previous_depth = nodes_depths.Get(node_dep, 0);
              nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
          }
      }

      // Handle inputs that are not connected to outputs.
      // We do not error out here because the inputs may be used to compute losses
      // and metrics.
      foreach (var input_t in inputs)
      {
          var (input_layer, _, _) = input_t.KerasHistory;
          if (!layers_depths.ContainsKey(input_layer))
          {
              layers_depths[input_layer] = 0;
              layer_indices[input_layer] = -1;
              nodes_depths[input_layer.InboundNodes[0]] = 0;
              network_nodes.Add(MakeNodeKey(input_layer.Name, 0));
          }
      }

      // Build a dict {depth: list of nodes with this depth}
      var nodes_by_depth = new Dictionary<int, List<INode>>();
      foreach (var (node, depth) in enumerate(nodes_depths))
      {
          if (!nodes_by_depth.TryGetValue(depth, out var nodes_at_depth))
          {
              nodes_at_depth = new List<INode>();
              nodes_by_depth[depth] = nodes_at_depth;
          }
          nodes_at_depth.Add(node);
      }

      // Build a dict {depth: list of layers with this depth}
      var layers_by_depth = new Dictionary<int, List<ILayer>>();
      foreach (var (layer, depth) in enumerate(layers_depths))
      {
          if (!layers_by_depth.TryGetValue(depth, out var layers_at_depth))
          {
              layers_at_depth = new List<ILayer>();
              layers_by_depth[depth] = layers_at_depth;
          }
          layers_at_depth.Add(layer);
      }

      // Set self.layers ordered by depth, highest depth first.
      var layers = new List<ILayer>();
      foreach (var depth in layers_by_depth.Keys.OrderByDescending(x => x))
      {
          // Network.layers needs to have a deterministic order:
          // here we order them by traversal order.
          layers.AddRange(layers_by_depth[depth].OrderBy(x => layer_indices[x]));
      }

      return (network_nodes.ToArray(), nodes_by_depth, layers, layers_by_depth);
  }
  /// <summary>
  /// Builds the key that identifies one inbound node of a layer.
  /// </summary>
  /// <param name="layer_name">Name of the layer owning the node.</param>
  /// <param name="node_index">Index of the node among the layer's inbound nodes.</param>
  /// <returns>The key in the form <c>{layer_name}_ib-{node_index}</c>.</returns>
  string MakeNodeKey(string layer_name, int node_index)
  {
      return $"{layer_name}_ib-{node_index}";
  }
  /// <summary>
  /// Walks the graph backwards from every output tensor and collects its nodes.
  /// </summary>
  /// <param name="outputs">Output tensors from which the traversal starts.</param>
  /// <returns>
  /// The visited nodes, each listed after the nodes that feed it, together with
  /// the index at which each layer was first reached during the traversal.
  /// </returns>
  (List<INode>, Dictionary<ILayer, int>) BuildMap(Tensors outputs)
  {
      var layer_indices = new Dictionary<ILayer, int>();
      var ordered_nodes = new List<INode>();
      var in_progress = new List<INode>();
      var finished = new List<INode>();

      foreach (var output_tensor in outputs)
      {
          BuildMapHelper(output_tensor, finished, in_progress, ordered_nodes, layer_indices);
      }

      return (ordered_nodes, layer_indices);
  }
  /// <summary>
  /// Recursive step of <c>BuildMap</c>: visits the node that produced
  /// <paramref name="tensor"/> after visiting the nodes producing its inputs.
  /// </summary>
  /// <param name="tensor">Tensor whose producing node is visited.</param>
  /// <param name="finished_nodes">Nodes that have been fully processed.</param>
  /// <param name="nodes_in_progress">Nodes on the current recursion path.</param>
  /// <param name="nodes_in_decreasing_depth">Receives each node once it is finished.</param>
  /// <param name="layer_indices">Receives the first-visit index of each layer.</param>
  /// <exception cref="ValueError">The node is already on the current recursion path.</exception>
  void BuildMapHelper(Tensor tensor,
      List<INode> finished_nodes,
      List<INode> nodes_in_progress,
      List<INode> nodes_in_decreasing_depth,
      Dictionary<ILayer, int> layer_indices)
  {
      var (owner, inbound_index, _) = tensor.KerasHistory;
      var current = owner.InboundNodes[inbound_index] as Node;

      // Don't repeat work for shared subgraphs
      if (finished_nodes.Contains(current))
      {
          return;
      }

      // Prevent cycles.
      if (nodes_in_progress.Contains(current))
      {
          throw new ValueError($"The tensor {tensor.name} at layer {owner.Name} is part of a cycle.");
      }

      // Store the traversal order for layer sorting.
      if (!layer_indices.ContainsKey(owner))
      {
          layer_indices.Add(owner, layer_indices.Count);
      }

      // Propagate to all previous tensors connected to this node.
      nodes_in_progress.Add(current);
      if (!current.is_input)
      {
          foreach (var upstream in current.KerasInputs)
          {
              BuildMapHelper(upstream, finished_nodes, nodes_in_progress, nodes_in_decreasing_depth, layer_indices);
          }
      }

      finished_nodes.Add(current);
      nodes_in_progress.Remove(current);
      nodes_in_decreasing_depth.Add(current);
  }
  /// <summary>
  /// Runs the layer graph on the given inputs and returns the model outputs.
  /// </summary>
  /// <remarks>
  /// Each computed tensor is queued once per recorded usage so that consumers can
  /// dequeue their own copy. A tensor without a recorded usage (for example an input
  /// that no node consumes) gets an empty queue rather than causing a
  /// <c>KeyNotFoundException</c>.
  /// </remarks>
  /// <param name="inputs">Concrete input tensors, paired by position with the model inputs.</param>
  /// <param name="state">Not used by this method.</param>
  /// <param name="training">Forwarded to each layer; <c>null</c> is treated as <c>false</c>.</param>
  /// <param name="optional_args">Not used by this method.</param>
  /// <returns>One tensor per model output, in output order.</returns>
  protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
  {
      // Builds a queue holding `value` once per recorded usage; a missing count is zero.
      Queue<Tensor> Replicate(long tensor_id, Tensor value)
      {
          tensor_usage_count.TryGetValue(tensor_id, out var uses);
          return new Queue<Tensor>(Enumerable.Repeat(value, uses));
      }

      var tensor_dict = new Dictionary<long, Queue<Tensor>>();

      // map input values
      foreach (var (symbolic, actual) in zip(this.inputs, inputs))
      {
          tensor_dict[symbolic.Id] = Replicate(symbolic.Id, actual);
      }

      var depth_keys = NodesByDepth.Keys.OrderByDescending(d => d).ToArray();
      foreach (var depth in depth_keys)
      {
          var nodes = NodesByDepth[depth];
          foreach (Node node in nodes)
          {
              // Input tensors already exist.
              if (node.is_input)
              {
                  continue;
              }

              var layer_inputs = node.MapArguments(tensor_dict);

              tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
              var layer_outputs = node.Layer.Apply(layer_inputs, training: training ?? false);
              foreach (var produced in layer_outputs.Where(t => t != null))
              {
                  tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {produced.shape}");
              }

              // Update tensor_dict for next or later input
              foreach (var (x_id, y) in zip(node.Outputs.Select(t => t.Id), layer_outputs))
              {
                  tensor_dict[x_id] = Replicate(x_id, y);
              }
          }
      }

      var output_tensors = new Tensors();
      foreach (var model_output in this.outputs)
      {
          output_tensors.Add(tensor_dict[model_output.Id].Dequeue());
      }

      return output_tensors;
  }
  /// <summary>
  /// Returns the trackable children of this model: the trackables of the layer
  /// checkpoint dependencies followed by the children reported by the base class.
  /// </summary>
  /// <param name="save_type">Forwarded to the base implementation.</param>
  /// <param name="cache">Forwarded to the base implementation.</param>
  /// <returns>A new dictionary combining both sets of entries.</returns>
  public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
  {
      var layer_children = LayerCheckpointDependencies
          .ToDictionary(entry => entry.Key, entry => entry.Value.GetTrackable());
      var inherited_children = base._trackable_children(save_type, cache);

      return layer_children
          .Concat(inherited_children)
          .ToDictionary(entry => entry.Key, entry => entry.Value);
  }
  /// <summary>
  /// Sets the model name: the given name when it is non-empty, otherwise a unique
  /// name generated from the snake-cased class name.
  /// </summary>
  /// <param name="name">Requested name; null or empty triggers name generation.</param>
  /// <param name="zero_based">Forwarded to <c>unique_layer_name</c>.</param>
  protected override void _init_set_name(string name, bool zero_based = true)
  {
      if (!string.IsNullOrEmpty(name))
      {
          this.name = name;
          return;
      }

      // A plain Functional instance is named after "Model"; subclasses use their own type name.
      string class_name = GetType() == typeof(Functional) ? "Model" : GetType().Name;
      this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based);
  }
  342. }
  343. }