You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Functional.cs 14 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.Utils;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Keras.Engine
  8. {
  9. /// <summary>
  10. /// A `Functional` model is a `Model` defined as a directed graph of layers.
  11. /// </summary>
  12. public partial class Functional : Model
  13. {
  14. TensorShape _build_input_shape;
  15. bool _compute_output_and_mask_jointly;
  16. bool _expects_training_arg;
  17. bool _expects_mask_arg;
  18. bool _autocast;
  19. List<ILayer> _output_layers;
  20. List<ILayer> _input_layers;
  21. List<KerasHistory> _input_coordinates;
  22. List<KerasHistory> _output_coordinates;
  23. public string[] NetworkNodes { get; set; }
  24. Dictionary<long, int> tensor_usage_count;
  25. public Functional(Tensors inputs, Tensors outputs, string name = null)
  26. : base(new ModelArgs
  27. {
  28. Name = name,
  29. Inputs = inputs,
  30. Outputs = outputs
  31. })
  32. {
  33. _input_layers = new List<ILayer>();
  34. _output_layers = new List<ILayer>();
  35. _input_coordinates = new List<KerasHistory>();
  36. _output_coordinates = new List<KerasHistory>();
  37. tensor_usage_count = new Dictionary<long, int>();
  38. if (this is Sequential)
  39. return;
  40. _init_graph_network(inputs, outputs);
  41. }
  42. protected void _init_graph_network(Tensors inputs, Tensors outputs)
  43. {
  44. _is_graph_network = true;
  45. this.inputs = inputs;
  46. this.outputs = outputs;
  47. built = true;
  48. _build_input_shape = inputs.shape;
  49. _compute_output_and_mask_jointly = true;
  50. _expects_training_arg = true;
  51. _expects_mask_arg = true;
  52. // A graph network does not autocast inputs, as its layers will cast them instead.
  53. _autocast = false;
  54. if (outputs.Any(x => x.KerasHistory == null))
  55. base_layer_utils.create_keras_history(outputs);
  56. // Build self._output_layers:
  57. foreach (var x in outputs)
  58. {
  59. var (layer, node_index, tensor_index) = x.KerasHistory;
  60. _output_layers.append(layer);
  61. _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
  62. }
  63. // Build self._input_layers:
  64. foreach (var x in inputs)
  65. {
  66. var (layer, node_index, tensor_index) = x.KerasHistory;
  67. _input_layers.append(layer);
  68. _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
  69. }
  70. // Keep track of the network's nodes and layers.
  71. var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs);
  72. NetworkNodes = nodes;
  73. NodesByDepth = nodes_by_depth;
  74. _layers = layers;
  75. // Build self.input_names and self.output_names.
  76. _set_output_names();
  77. ComputeTensorUsageCount();
  78. }
  79. /// <summary>
  80. /// Assigns unique names to the Network's outputs.
  81. /// </summary>
  82. void _set_output_names()
  83. {
  84. var uniquified = new List<string>();
  85. var output_names = new List<string>();
  86. var prefix_count = new Dictionary<string, int>();
  87. foreach (var layer in _output_layers)
  88. {
  89. var proposal = layer.Name;
  90. while (output_names.Contains(proposal))
  91. {
  92. var existing_count = prefix_count.Get(layer.Name, 1);
  93. proposal = $"{layer.Name}_{existing_count}";
  94. prefix_count[layer.Name] = existing_count + 1;
  95. }
  96. output_names.add(proposal);
  97. uniquified.append(proposal);
  98. }
  99. this.output_names = uniquified.ToArray();
  100. }
  101. void ComputeTensorUsageCount()
  102. {
  103. var available_tensors = inputs.Select(x => x.Id).ToList();
  104. var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
  105. foreach (var depth in depth_keys)
  106. {
  107. foreach (var node in NodesByDepth[depth])
  108. {
  109. var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray();
  110. if (input_tensors.issubset(available_tensors))
  111. {
  112. foreach (var tensor in node.KerasInputs)
  113. {
  114. if (!tensor_usage_count.ContainsKey(tensor.Id))
  115. tensor_usage_count[tensor.Id] = 0;
  116. tensor_usage_count[tensor.Id] += 1;
  117. }
  118. foreach (var output_tensor in node.Outputs)
  119. available_tensors.Add(output_tensor.Id);
  120. }
  121. }
  122. }
  123. foreach (var tensor in outputs)
  124. {
  125. if (!tensor_usage_count.ContainsKey(tensor.Id))
  126. tensor_usage_count[tensor.Id] = 0;
  127. tensor_usage_count[tensor.Id] += 1;
  128. }
  129. }
  130. /// <summary>
  131. /// Validates a network's topology and gather its layers and nodes.
  132. /// </summary>
  133. /// <param name="inputs"></param>
  134. /// <param name="outputs"></param>
  135. (string[], Dictionary<int, List<INode>>, List<ILayer>, Dictionary<int, List<ILayer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
  136. {
  137. var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);
  138. var network_nodes = nodes_in_decreasing_depth
  139. .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
  140. .ToArray();
  141. var nodes_depths = new Dictionary<INode, int>();
  142. var layers_depths = new Dictionary<ILayer, int>();
  143. nodes_in_decreasing_depth.Reverse();
  144. foreach (var node in nodes_in_decreasing_depth)
  145. {
  146. // If the depth is not set, the node has no outbound nodes (depth 0).
  147. int depth = nodes_depths.SetDefault(node, 0);
  148. // Update the depth of the corresponding layer
  149. int previous_depth = layers_depths.Get(node.Layer, 0);
  150. // If we've seen this layer before at a higher depth,
  151. // we should use that depth instead of the node depth.
  152. // This is necessary for shared layers that have inputs at different
  153. // depth levels in the graph.
  154. depth = Math.Max(depth, previous_depth);
  155. layers_depths[node.Layer] = depth;
  156. nodes_depths[node] = depth;
  157. // Update the depth of inbound nodes.
  158. // The "depth" of a node is the max of the depths
  159. // of all nodes it is connected to + 1.
  160. foreach (var node_dep in node.ParentNodes)
  161. {
  162. previous_depth = nodes_depths.Get(node_dep, 0);
  163. nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
  164. }
  165. }
  166. // Handle inputs that are not connected to outputs.
  167. // We do not error out here because the inputs may be used to compute losses
  168. // and metrics.
  169. foreach (var input_t in inputs)
  170. {
  171. var (input_layer, _, _) = input_t.KerasHistory;
  172. if (!layers_depths.ContainsKey(input_layer))
  173. {
  174. layers_depths[input_layer] = 0;
  175. layer_indices[input_layer] = -1;
  176. nodes_depths[input_layer.InboundNodes[0]] = 0;
  177. network_nodes.add(MakeNodeKey(input_layer.Name, 0));
  178. }
  179. }
  180. // Build a dict {depth: list of nodes with this depth}
  181. var nodes_by_depth = new Dictionary<int, List<INode>>();
  182. foreach (var (node, depth) in enumerate(nodes_depths))
  183. {
  184. if (!nodes_by_depth.ContainsKey(depth))
  185. nodes_by_depth[depth] = new List<INode>();
  186. nodes_by_depth[depth].append(node);
  187. }
  188. var layers_by_depth = new Dictionary<int, List<ILayer>>();
  189. foreach (var (layer, depth) in enumerate(layers_depths))
  190. {
  191. if (!layers_by_depth.ContainsKey(depth))
  192. layers_by_depth[depth] = new List<ILayer>();
  193. layers_by_depth[depth].append(layer);
  194. }
  195. // Get sorted list of layer depths.
  196. var depth_keys = layers_by_depth.Keys.OrderBy(x => x).Reverse();
  197. // Set self.layers ordered by depth.
  198. var layers = new List<ILayer>();
  199. foreach (var depth in depth_keys)
  200. {
  201. var layers_for_depth = layers_by_depth[depth];
  202. // Network.layers needs to have a deterministic order:
  203. // here we order them by traversal order.
  204. layers_for_depth = layers_for_depth.OrderBy(x => layer_indices[x]).ToList();
  205. layers.AddRange(layers_for_depth);
  206. }
  207. // Get sorted list of node depths.
  208. depth_keys = nodes_by_depth.Keys.OrderBy(x => x).Reverse();
  209. return (network_nodes, nodes_by_depth, layers, layers_by_depth);
  210. }
  211. string MakeNodeKey(string layer_name, int node_index)
  212. => $"{layer_name}_ib-{node_index}";
  213. /// <summary>
  214. /// This method topologically sorts nodes in order from inputs to outputs.
  215. /// </summary>
  216. /// <param name="outputs"></param>
  217. (List<INode>, Dictionary<ILayer, int>) BuildMap(Tensors outputs)
  218. {
  219. var finished_nodes = new List<INode>();
  220. var nodes_in_progress = new List<INode>();
  221. var nodes_in_decreasing_depth = new List<INode>();
  222. var layer_indices = new Dictionary<ILayer, int>();
  223. foreach (var output in outputs)
  224. BuildMapHelper(output,
  225. finished_nodes,
  226. nodes_in_progress,
  227. nodes_in_decreasing_depth,
  228. layer_indices);
  229. return (nodes_in_decreasing_depth, layer_indices);
  230. }
  231. void BuildMapHelper(Tensor tensor,
  232. List<INode> finished_nodes,
  233. List<INode> nodes_in_progress,
  234. List<INode> nodes_in_decreasing_depth,
  235. Dictionary<ILayer, int> layer_indices)
  236. {
  237. var (layer, node_index, _) = tensor.KerasHistory;
  238. var node = layer.InboundNodes[node_index] as Node;
  239. // Don't repeat work for shared subgraphs
  240. if (finished_nodes.Contains(node))
  241. return;
  242. // Prevent cycles.
  243. if (nodes_in_progress.Contains(node))
  244. throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");
  245. // Store the traversal order for layer sorting.
  246. if (!layer_indices.ContainsKey(layer))
  247. layer_indices[layer] = layer_indices.Count;
  248. // Propagate to all previous tensors connected to this node.
  249. nodes_in_progress.Add(node);
  250. if (!node.is_input)
  251. {
  252. foreach (var k_tensor in node.KerasInputs)
  253. {
  254. BuildMapHelper(k_tensor,
  255. finished_nodes,
  256. nodes_in_progress,
  257. nodes_in_decreasing_depth,
  258. layer_indices);
  259. }
  260. }
  261. finished_nodes.Add(node);
  262. nodes_in_progress.Remove(node);
  263. nodes_in_decreasing_depth.append(node);
  264. }
  265. protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
  266. {
  267. return run_internal_graph(inputs, is_training);
  268. }
  269. Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
  270. {
  271. if (mask == null)
  272. {
  273. Tensor[] masks = new Tensor[inputs.Count()];
  274. foreach (var (i, input_t) in enumerate(inputs))
  275. input_t.KerasMask = masks[i];
  276. }
  277. var tensor_dict = new Dictionary<long, Queue<Tensor>>();
  278. foreach (var (x, y) in zip(this.inputs, inputs))
  279. {
  280. var y1 = conform_to_reference_input(y, x);
  281. tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1));
  282. }
  283. var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
  284. foreach (var depth in depth_keys)
  285. {
  286. var nodes = NodesByDepth[depth];
  287. foreach (Node node in nodes)
  288. {
  289. // Input tensors already exist.
  290. if (node.is_input)
  291. continue;
  292. var layer_inputs = node.MapArguments(tensor_dict);
  293. tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
  294. var outputs = node.Layer.Apply(layer_inputs, is_training: training);
  295. foreach (var output in outputs.Where(x => x != null))
  296. tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
  297. // Update tensor_dict for next input
  298. foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
  299. tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));
  300. }
  301. }
  302. var output_tensors = new Tensors();
  303. foreach (var x in outputs)
  304. output_tensors.Add(tensor_dict[x.Id].Dequeue());
  305. return output_tensors;
  306. }
  307. Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
  308. {
  309. return tensor;
  310. }
  311. }
  312. }