From 8b9fca47e98064d16cecdd1f436e114befdad4db Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 15 Nov 2020 11:38:16 -0600 Subject: [PATCH] Add save and restore model from config. --- src/TensorFlowNET.Core/Binding.Util.cs | 6 + .../Eager/EagerRunner.RecordGradient.cs | 2 +- src/TensorFlowNET.Core/Gradients/math_grad.cs | 14 +- src/TensorFlowNET.Core/Keras/Engine/INode.cs | 2 +- .../Keras/Engine/KerasHistory.cs | 2 + .../Keras/Saving/LayerConfig.cs | 2 +- .../Keras/Saving/ModelConfig.cs | 7 +- .../Keras/Saving/NodeConfig.cs | 3 + .../Functional.ConnectAncillaryLayers.cs | 23 +++ .../Engine/Functional.FromConfig.cs | 140 ++++++++++++++++++ .../Engine/Functional.GetConfig.cs | 44 +++++- src/TensorFlowNET.Keras/Engine/Layer.cs | 2 +- .../Engine/Node.Serialize.cs | 16 +- src/TensorFlowNET.Keras/KerasInterface.cs | 3 + src/TensorFlowNET.Keras/Layers/Dense.cs | 5 + src/TensorFlowNET.Keras/Layers/InputLayer.cs | 6 +- src/TensorFlowNET.Keras/Models/ModelsApi.cs | 14 ++ .../Keras/ModelSaveTest.cs | 4 +- 18 files changed, 269 insertions(+), 26 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs create mode 100644 src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs create mode 100644 src/TensorFlowNET.Keras/Models/ModelsApi.cs diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index b3607eb1..0411227f 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -58,6 +58,12 @@ namespace Tensorflow public static void append(this IList list, T element) => list.Insert(list.Count, element); + public static void append(this IList list, IList elements) + { + for (int i = 0; i < elements.Count(); i++) + list.Insert(list.Count, elements[i]); + } + public static T[] concat(this IList list1, IList list2) { var list = new List(); diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index 917e3d1c..91b96230 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -38,7 +38,7 @@ namespace Tensorflow.Eager }*/ } - Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); + // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); if (!should_record) return should_record; Tensor[] op_outputs; diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 51956746..424493c6 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -761,13 +761,6 @@ namespace Tensorflow.Gradients { sx = array_ops.shape(x); sy = array_ops.shape(y); - - var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); - return new[] - { - (sx, rx, true), - (sy, ry, true) - }; } else { @@ -775,7 +768,12 @@ namespace Tensorflow.Gradients sy = array_ops.shape_internal(y, optimize: false); } - throw new NotImplementedException(""); + var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); + return new[] + { + (sx, rx, true), + (sy, ry, true) + }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/INode.cs b/src/TensorFlowNET.Core/Keras/Engine/INode.cs index dde0f8ea..f1ac2555 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/INode.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/INode.cs @@ -13,6 +13,6 @@ namespace Tensorflow.Keras.Engine INode[] ParentNodes { get; } IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); bool is_input { get; } - NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map); + List serialize(Func make_node_key, Dictionary node_conversion_map); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs index d218c17e..b827daeb 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -8,7 +8,9 @@ ILayer layer; public ILayer Layer => layer; int node_index; + public int NodeIndex => node_index; int tensor_index; + public int TensorIndex => tensor_index; Tensor tensor; public KerasHistory(ILayer layer, int node_index, int tensor_index, Tensor tensor) diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs index 950b3132..b8b8cab4 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -11,6 +11,6 @@ namespace Tensorflow.Keras.Saving public string Name { get; set; } public string ClassName { get; set; } public LayerArgs Config { get; set; } - public List InboundNodes { get; set; } + public List InboundNodes { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index fa965aa4..abfb235b 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -9,7 +9,10 @@ namespace Tensorflow.Keras.Saving { public string Name { get; set; } public List Layers { get; set; } - public List InputLayers { get; set; } - public List OutputLayers { get; set; } + public List InputLayers { get; set; } + public List OutputLayers { get; set; } + + public override string ToString() + => $"{Name}, {Layers.Count} Layers, {InputLayers.Count} Input Layers, {OutputLayers.Count} Output Layers"; } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 732d9d4d..3132248e 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -9,5 +9,8 @@ namespace Tensorflow.Keras.Saving public string Name { get; set; } public int NodeIndex { get; set; } public int TensorIndex { get; set; } + + public override string ToString() + => $"{Name}, {NodeIndex}, {TensorIndex}"; } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs b/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs new file mode 100644 index 00000000..0002aed1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + /// + /// Adds layers that are not connected to the outputs to the model. + /// + /// + public void connect_ancillary_layers(Dictionary created_layers) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs new file mode 100644 index 00000000..b0d1b2b6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -0,0 +1,140 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Functional + { + public static Functional from_config(ModelConfig config) + { + var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); + var model = new Functional(input_tensors, output_tensors, name: config.Name); + model.connect_ancillary_layers(created_layers); + return model; + } + + /// + /// Reconstructs graph from config object. + /// + /// + /// + static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config) + { + // Layer instances created during the graph reconstruction process. + var created_layers = new Dictionary(); + var node_index_map = new Dictionary<(string, int), int>(); + var node_count_by_layer = new Dictionary(); + var unprocessed_nodes = new Dictionary(); + // First, we create all layers and enqueue nodes to be processed + foreach (var layer_data in config.Layers) + process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); + + // Then we process nodes in order of layer depth. + // Nodes that cannot yet be processed (if the inbound node + // does not yet exist) are re-enqueued, and the process + // is repeated until all nodes are processed. + while (unprocessed_nodes.Count > 0) + { + foreach(var layer_data in config.Layers) + { + var layer = created_layers[layer_data.Name]; + if (unprocessed_nodes.ContainsKey(layer)) + { + var node_data = unprocessed_nodes[layer]; + // foreach (var node_data in unprocessed_nodes[layer]) + { + process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map); + unprocessed_nodes.Remove(layer); + } + } + } + } + + var input_tensors = new List(); + foreach (var layer_data in config.InputLayers) + { + var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); + var layer = created_layers[layer_name]; + var layer_output_tensors = layer.InboundNodes[node_index].Outputs; + input_tensors.append(layer_output_tensors[tensor_index]); + } + + var output_tensors = new List(); + foreach (var layer_data in config.OutputLayers) + { + var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); + var layer = created_layers[layer_name]; + var layer_output_tensors = layer.InboundNodes[node_index].Outputs; + output_tensors.append(layer_output_tensors[tensor_index]); + } + + return (input_tensors, output_tensors, created_layers); + } + + static void process_layer(Dictionary created_layers, + LayerConfig layer_data, + Dictionary unprocessed_nodes, + Dictionary node_count_by_layer) + { + ILayer layer = null; + var layer_name = layer_data.Name; + if (created_layers.ContainsKey(layer_name)) + layer = created_layers[layer_name]; + else + { + layer = layer_data.ClassName switch + { + "InputLayer" => InputLayer.from_config(layer_data.Config), + "Dense" => Dense.from_config(layer_data.Config), + _ => throw new NotImplementedException("") + }; + + created_layers[layer_name] = layer; + } + node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0; + + var inbound_nodes_data = layer_data.InboundNodes; + foreach (var node_data in inbound_nodes_data) + { + if (!unprocessed_nodes.ContainsKey(layer)) + unprocessed_nodes[layer] = node_data; + else + unprocessed_nodes.Add(layer, node_data); + } + } + + static void process_node(ILayer layer, + NodeConfig node_data, + Dictionary created_layers, + Dictionary node_count_by_layer, + Dictionary<(string, int), int> node_index_map) + { + var input_tensors = new List(); + var inbound_layer_name = node_data.Name; + var inbound_node_index = node_data.NodeIndex; + var inbound_tensor_index = node_data.TensorIndex; + + var inbound_layer = created_layers[inbound_layer_name]; + var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; + input_tensors.Add(inbound_node.Outputs[inbound_node_index]); + + var output_tensors = layer.Apply(input_tensors); + + // Update node index map. + var output_index = output_tensors[0].KerasHistory.NodeIndex; + node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index; + node_count_by_layer[layer] += 1; + } + + static bool _should_skip_first_node(ILayer layer) + { + return layer is Functional && layer.Layers[0] is InputLayer; + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 96ea11b6..6615810b 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -44,14 +44,14 @@ namespace Tensorflow.Keras.Engine var layer_configs = new List(); foreach (var layer in _layers) { - var filtered_inbound_nodes = new List(); + var filtered_inbound_nodes = new List(); foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) { var node_key = _make_node_key(layer.Name, original_node_index); if (NetworkNodes.Contains(node_key) && !node.is_input) { var node_data = node.serialize(_make_node_key, node_conversion_map); - throw new NotImplementedException(""); + filtered_inbound_nodes.append(node_data); } } @@ -62,12 +62,42 @@ namespace Tensorflow.Keras.Engine } config.Layers = layer_configs; - return config; - } + // Gather info about inputs and outputs. + var model_inputs = new List(); + foreach (var i in range(_input_layers.Count)) + { + var (layer, node_index, tensor_index) = _input_coordinates[i]; + var node_key = _make_node_key(layer.Name, node_index); + if (!NetworkNodes.Contains(node_key)) + continue; + var new_node_index = node_conversion_map[node_key]; + model_inputs.append(new NodeConfig + { + Name = layer.Name, + NodeIndex = new_node_index, + TensorIndex = tensor_index + }); + } + config.InputLayers = model_inputs; - bool _should_skip_first_node(ILayer layer) - { - return layer is Functional && layer.Layers[0] is InputLayer; + var model_outputs = new List(); + foreach (var i in range(_output_layers.Count)) + { + var (layer, node_index, tensor_index) = _output_coordinates[i]; + var node_key = _make_node_key(layer.Name, node_index); + if (!NetworkNodes.Contains(node_key)) + continue; + var new_node_index = node_conversion_map[node_key]; + model_outputs.append(new NodeConfig + { + Name = layer.Name, + NodeIndex = new_node_index, + TensorIndex = tensor_index + }); + } + config.OutputLayers = model_outputs; + + return config; } string _make_node_key(string layer_name, int node_index) diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 93c9f91f..c434cea1 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -244,6 +244,6 @@ namespace Tensorflow.Keras.Engine } public virtual LayerArgs get_config() - => throw new NotImplementedException(""); + => args; } } diff --git a/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs index 05d544f8..7c8c805b 100644 --- a/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Node.Serialize.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.Saving; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { @@ -10,9 +12,19 @@ namespace Tensorflow.Keras.Engine /// Serializes `Node` for Functional API's `get_config`. /// /// - public NodeConfig serialize(Func make_node_key, Dictionary node_conversion_map) + public List serialize(Func make_node_key, Dictionary node_conversion_map) { - throw new NotImplementedException(""); + return KerasInputs.Select(x => { + var kh = x.KerasHistory; + var node_key = make_node_key(kh.Layer.Name, kh.NodeIndex); + var new_node_index = node_conversion_map.Get(node_key, 0); + return new NodeConfig + { + Name = kh.Layer.Name, + NodeIndex = new_node_index, + TensorIndex = kh.TensorIndex + }; + }).ToList(); } } } diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 5455148f..6cb733d3 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -5,7 +5,9 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Models; using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras { @@ -21,6 +23,7 @@ namespace Tensorflow.Keras public BackendImpl backend { get; } = new BackendImpl(); public OptimizerApi optimizers { get; } = new OptimizerApi(); public MetricsApi metrics { get; } = new MetricsApi(); + public ModelsApi models { get; } = new ModelsApi(); public Sequential Sequential(List layers = null, string name = null) diff --git a/src/TensorFlowNET.Keras/Layers/Dense.cs b/src/TensorFlowNET.Keras/Layers/Dense.cs index 4c864ad1..a01f3df7 100644 --- a/src/TensorFlowNET.Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Dense.cs @@ -84,5 +84,10 @@ namespace Tensorflow.Keras.Layers return outputs; } + + public static Dense from_config(LayerArgs args) + { + return new Dense(args as DenseArgs); + } } } diff --git a/src/TensorFlowNET.Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/InputLayer.cs index 32b566ea..49814f42 100644 --- a/src/TensorFlowNET.Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/InputLayer.cs @@ -101,7 +101,9 @@ namespace Tensorflow.Keras.Layers tf.Context.restore_mode(); } - public override LayerArgs get_config() - => args; + public static InputLayer from_config(LayerArgs args) + { + return new InputLayer(args as InputLayerArgs); + } } } diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs new file mode 100644 index 00000000..b575df27 --- /dev/null +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Models +{ + public class ModelsApi + { + public Functional from_config(ModelConfig config) + => Functional.from_config(config); + } +} diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs index 886e30fc..901ecf02 100644 --- a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs @@ -10,11 +10,13 @@ namespace TensorFlowNET.UnitTest.Keras [TestClass] public class ModelSaveTest : EagerModeTestBase { - [TestMethod, Ignore] + [TestMethod] public void GetAndFromConfig() { var model = GetFunctionalModel(); var config = model.get_config(); + var new_model = keras.models.from_config(config); + Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); } Functional GetFunctionalModel()