diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index aca45146..4e23b11a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -8,6 +8,7 @@ namespace Tensorflow.Keras { string Name { get; } bool Trainable { get; } + bool Built { get; } List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 0dd40096..42afc262 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -75,6 +75,8 @@ namespace Tensorflow public TensorShape BatchInputShape => throw new NotImplementedException(); public TF_DataType DType => throw new NotImplementedException(); + protected bool built = false; + public bool Built => built; public RnnCell(bool trainable = true, string name = null, diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index fc5d3de9..40ca550c 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -44,6 +44,7 @@ namespace Tensorflow.Keras.Engine /// the layer's weights. /// protected bool built; + public bool Built => built; public bool Trainable => args.Trainable; public TF_DataType DType => args.DType; diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index d06810f5..2b37d2bf 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -30,18 +30,13 @@ namespace Tensorflow.Keras.Engine public class Sequential : Functional { SequentialArgs args; - bool _is_graph_network; - Tensors inputs; - Tensors outputs; bool _compute_output_and_mask_jointly; bool _auto_track_sub_layers; TensorShape _inferred_input_shape; bool _has_explicit_input_shape; - TF_DataType _input_dtype; public TensorShape output_shape => outputs[0].TensorShape; - bool built = false; public Sequential(SequentialArgs args) : base(args.Inputs, args.Outputs, name: args.Name) diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs index b575df27..73b77bc4 100644 --- a/src/TensorFlowNET.Keras/Models/ModelsApi.cs +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; namespace Tensorflow.Keras.Models { @@ -10,5 +12,21 @@ namespace Tensorflow.Keras.Models { public Functional from_config(ModelConfig config) => Functional.from_config(config); + + public void load_model(string filepath, bool compile = true) + { + var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); + var saved_mode = SavedModel.Parser.ParseFrom(bytes); + + var meta_graph_def = saved_mode.MetaGraphs[0]; + var object_graph_def = meta_graph_def.ObjectGraphDef; + + bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); + var metadata = SavedMetadata.Parser.ParseFrom(bytes); + + // Recreate layers and metrics using the info stored in the metadata. + var keras_loader = new KerasObjectLoader(metadata, object_graph_def); + keras_loader.load_layers(compile: compile); + } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs new file mode 100644 index 00000000..7646695b --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -0,0 +1,23 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public class KerasMetaData + { + public string Name { get; set; } + [JsonProperty("class_name")] + public string ClassName { get; set; } + [JsonProperty("is_graph_network")] + public bool IsGraphNetwork { get; set; } + [JsonProperty("shared_object_id")] + public int SharedObjectId { get; set; } + [JsonProperty("must_restore_from_config")] + public bool MustRestoreFromConfig { get; set; } + public ModelConfig Config { get; set; } + [JsonProperty("build_input_shape")] + public TensorShapeConfig BuildInputShape { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs new file mode 100644 index 00000000..82722cc1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -0,0 +1,161 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Saving +{ + public class KerasObjectLoader + { + SavedMetadata _metadata; + SavedObjectGraph _proto; + Dictionary _node_paths = new Dictionary(); + Dictionary model_layer_dependencies = new Dictionary(); + List _traversed_nodes_from_config = new List(); + + public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) + { + _metadata = metadata; + _proto = object_graph_def; + _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); + } + + /// + /// Load all layer nodes from the metadata. + /// + /// + public void load_layers(bool compile = true) + { + var metric_list = new List(); + foreach (var node_metadata in _metadata.Nodes) + { + if (node_metadata.Identifier == "_tf_keras_metric") + { + metric_list.Add(node_metadata); + continue; + } + + _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + } + } + + void _load_layer(int node_id, string identifier, string metadata_json) + { + metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); + var metadata = JsonConvert.DeserializeObject(metadata_json); + _revive_from_config(identifier, metadata, node_id); + } + + /// + /// Revives a layer/model from config, or returns None. + /// + /// + /// + /// + void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) + { + var obj = _revive_graph_network(identifier, metadata, node_id); + obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); + } + + Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) + { + var config = metadata.Config; + var class_name = metadata.ClassName; + Model model = null; + if (class_name == "Sequential") + { + model = new Sequential(new SequentialArgs + { + Name = config.Name + }); + } + else if (class_name == "Functional") + { + throw new NotImplementedException(""); + } + + if (!metadata.IsGraphNetwork) + return null; + + // Record this model and its layers. This will later be used to reconstruct + // the model. + var layers = _get_child_layer_node_ids(node_id); + model_layer_dependencies[node_id] = (model, layers); + return model; + } + + Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) + { + var config = metadata.Config; + var class_name = metadata.ClassName; + var shared_object_id = metadata.SharedObjectId; + var must_restore_from_config = metadata.MustRestoreFromConfig; + + return null; + } + + /// + /// Returns the node ids of each layer in a Sequential/Functional model. + /// + /// + int[] _get_child_layer_node_ids(int node_id) + { + int num_layers = 0; + Dictionary child_layers = new Dictionary(); + foreach (var child in _proto.Nodes[node_id].Children) + { + var m = Regex.Match(child.LocalName, @"layer-(\d+)"); + if (!m.Success) + continue; + var layer_n = int.Parse(m.Groups[1].Value); + num_layers = max(layer_n + 1, num_layers); + child_layers[layer_n] = child.NodeId; + } + + var ordered = new List(); + foreach (var n in range(num_layers)) + { + if (child_layers.ContainsKey(n)) + ordered.Add(child_layers[n]); + else + break; + } + return ordered.ToArray(); + } + + /// + /// Recursively records objects recreated from config. + /// + /// + /// + /// + void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) + { + if (_traversed_nodes_from_config.Contains(node_id)) + return; + var parent_path = _node_paths[node_id]; + _traversed_nodes_from_config.Add(node_id); + if (!obj.Built) + { + var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); + var metadata = JsonConvert.DeserializeObject(metadata_json); + _try_build_layer(obj, node_id, metadata.BuildInputShape); + } + } + + bool _try_build_layer(Model obj, int node_id, TensorShape build_input_shape) + { + if (obj.Built) + return true; + + return false; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs new file mode 100644 index 00000000..17772b8e --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Saving +{ + public class TensorShapeConfig + { + public string ClassName { get; set; } + public int?[] Items { get; set; } + + public static implicit operator TensorShape(TensorShapeConfig shape) + => new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs b/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs index f9ae33b6..bdb06da7 100644 --- a/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/OutputTest.cs @@ -8,7 +8,7 @@ using static Tensorflow.Binding; using static Tensorflow.KerasApi; using Tensorflow.Keras; -namespace Tensorflow.Keras.UnitTest +namespace TensorFlowNET.Keras.UnitTest { [TestClass] public class OutputTest