@@ -8,6 +8,7 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
string Name { get; } | string Name { get; } | ||||
bool Trainable { get; } | bool Trainable { get; } | ||||
bool Built { get; } | |||||
List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
@@ -75,6 +75,8 @@ namespace Tensorflow | |||||
public TensorShape BatchInputShape => throw new NotImplementedException(); | public TensorShape BatchInputShape => throw new NotImplementedException(); | ||||
public TF_DataType DType => throw new NotImplementedException(); | public TF_DataType DType => throw new NotImplementedException(); | ||||
protected bool built = false; | |||||
public bool Built => built; | |||||
public RnnCell(bool trainable = true, | public RnnCell(bool trainable = true, | ||||
string name = null, | string name = null, | ||||
@@ -44,6 +44,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// the layer's weights. | /// the layer's weights. | ||||
/// </summary> | /// </summary> | ||||
protected bool built; | protected bool built; | ||||
public bool Built => built; | |||||
public bool Trainable => args.Trainable; | public bool Trainable => args.Trainable; | ||||
public TF_DataType DType => args.DType; | public TF_DataType DType => args.DType; | ||||
@@ -30,18 +30,13 @@ namespace Tensorflow.Keras.Engine | |||||
public class Sequential : Functional | public class Sequential : Functional | ||||
{ | { | ||||
SequentialArgs args; | SequentialArgs args; | ||||
bool _is_graph_network; | |||||
Tensors inputs; | |||||
Tensors outputs; | |||||
bool _compute_output_and_mask_jointly; | bool _compute_output_and_mask_jointly; | ||||
bool _auto_track_sub_layers; | bool _auto_track_sub_layers; | ||||
TensorShape _inferred_input_shape; | TensorShape _inferred_input_shape; | ||||
bool _has_explicit_input_shape; | bool _has_explicit_input_shape; | ||||
TF_DataType _input_dtype; | |||||
public TensorShape output_shape => outputs[0].TensorShape; | public TensorShape output_shape => outputs[0].TensorShape; | ||||
bool built = false; | |||||
public Sequential(SequentialArgs args) | public Sequential(SequentialArgs args) | ||||
: base(args.Inputs, args.Outputs, name: args.Name) | : base(args.Inputs, args.Outputs, name: args.Name) | ||||
@@ -1,8 +1,10 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
namespace Tensorflow.Keras.Models | namespace Tensorflow.Keras.Models | ||||
{ | { | ||||
@@ -10,5 +12,21 @@ namespace Tensorflow.Keras.Models | |||||
{ | { | ||||
public Functional from_config(ModelConfig config) | public Functional from_config(ModelConfig config) | ||||
=> Functional.from_config(config); | => Functional.from_config(config); | ||||
public void load_model(string filepath, bool compile = true) | |||||
{ | |||||
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||||
var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||||
var meta_graph_def = saved_mode.MetaGraphs[0]; | |||||
var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||||
var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||||
// Recreate layers and metrics using the info stored in the metadata. | |||||
var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
keras_loader.load_layers(compile: compile); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,23 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class KerasMetaData | |||||
{ | |||||
public string Name { get; set; } | |||||
[JsonProperty("class_name")] | |||||
public string ClassName { get; set; } | |||||
[JsonProperty("is_graph_network")] | |||||
public bool IsGraphNetwork { get; set; } | |||||
[JsonProperty("shared_object_id")] | |||||
public int SharedObjectId { get; set; } | |||||
[JsonProperty("must_restore_from_config")] | |||||
public bool MustRestoreFromConfig { get; set; } | |||||
public ModelConfig Config { get; set; } | |||||
[JsonProperty("build_input_shape")] | |||||
public TensorShapeConfig BuildInputShape { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,161 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text.RegularExpressions; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | |||||
using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class KerasObjectLoader | |||||
{ | |||||
SavedMetadata _metadata; | |||||
SavedObjectGraph _proto; | |||||
Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||||
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||||
List<int> _traversed_nodes_from_config = new List<int>(); | |||||
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | |||||
{ | |||||
_metadata = metadata; | |||||
_proto = object_graph_def; | |||||
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | |||||
} | |||||
/// <summary> | |||||
/// Load all layer nodes from the metadata. | |||||
/// </summary> | |||||
/// <param name="compile"></param> | |||||
public void load_layers(bool compile = true) | |||||
{ | |||||
var metric_list = new List<ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject>(); | |||||
foreach (var node_metadata in _metadata.Nodes) | |||||
{ | |||||
if (node_metadata.Identifier == "_tf_keras_metric") | |||||
{ | |||||
metric_list.Add(node_metadata); | |||||
continue; | |||||
} | |||||
_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||||
} | |||||
} | |||||
void _load_layer(int node_id, string identifier, string metadata_json) | |||||
{ | |||||
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||||
_revive_from_config(identifier, metadata, node_id); | |||||
} | |||||
/// <summary> | |||||
/// Revives a layer/model from config, or returns None. | |||||
/// </summary> | |||||
/// <param name="identifier"></param> | |||||
/// <param name="metadata"></param> | |||||
/// <param name="node_id"></param> | |||||
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||||
{ | |||||
var obj = _revive_graph_network(identifier, metadata, node_id); | |||||
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||||
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | |||||
} | |||||
Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | |||||
{ | |||||
var config = metadata.Config; | |||||
var class_name = metadata.ClassName; | |||||
Model model = null; | |||||
if (class_name == "Sequential") | |||||
{ | |||||
model = new Sequential(new SequentialArgs | |||||
{ | |||||
Name = config.Name | |||||
}); | |||||
} | |||||
else if (class_name == "Functional") | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
if (!metadata.IsGraphNetwork) | |||||
return null; | |||||
// Record this model and its layers. This will later be used to reconstruct | |||||
// the model. | |||||
var layers = _get_child_layer_node_ids(node_id); | |||||
model_layer_dependencies[node_id] = (model, layers); | |||||
return model; | |||||
} | |||||
Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||||
{ | |||||
var config = metadata.Config; | |||||
var class_name = metadata.ClassName; | |||||
var shared_object_id = metadata.SharedObjectId; | |||||
var must_restore_from_config = metadata.MustRestoreFromConfig; | |||||
return null; | |||||
} | |||||
/// <summary> | |||||
/// Returns the node ids of each layer in a Sequential/Functional model. | |||||
/// </summary> | |||||
/// <param name="node_id"></param> | |||||
int[] _get_child_layer_node_ids(int node_id) | |||||
{ | |||||
int num_layers = 0; | |||||
Dictionary<int, int> child_layers = new Dictionary<int, int>(); | |||||
foreach (var child in _proto.Nodes[node_id].Children) | |||||
{ | |||||
var m = Regex.Match(child.LocalName, @"layer-(\d+)"); | |||||
if (!m.Success) | |||||
continue; | |||||
var layer_n = int.Parse(m.Groups[1].Value); | |||||
num_layers = max(layer_n + 1, num_layers); | |||||
child_layers[layer_n] = child.NodeId; | |||||
} | |||||
var ordered = new List<int>(); | |||||
foreach (var n in range(num_layers)) | |||||
{ | |||||
if (child_layers.ContainsKey(n)) | |||||
ordered.Add(child_layers[n]); | |||||
else | |||||
break; | |||||
} | |||||
return ordered.ToArray(); | |||||
} | |||||
/// <summary> | |||||
/// Recursively records objects recreated from config. | |||||
/// </summary> | |||||
/// <param name="obj"></param> | |||||
/// <param name="proto"></param> | |||||
/// <param name="node_id"></param> | |||||
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) | |||||
{ | |||||
if (_traversed_nodes_from_config.Contains(node_id)) | |||||
return; | |||||
var parent_path = _node_paths[node_id]; | |||||
_traversed_nodes_from_config.Add(node_id); | |||||
if (!obj.Built) | |||||
{ | |||||
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||||
_try_build_layer(obj, node_id, metadata.BuildInputShape); | |||||
} | |||||
} | |||||
bool _try_build_layer(Model obj, int node_id, TensorShape build_input_shape) | |||||
{ | |||||
if (obj.Built) | |||||
return true; | |||||
return false; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
namespace Tensorflow.Keras.Saving | |||||
{ | |||||
public class TensorShapeConfig | |||||
{ | |||||
public string ClassName { get; set; } | |||||
public int?[] Items { get; set; } | |||||
public static implicit operator TensorShape(TensorShapeConfig shape) | |||||
=> new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
} | |||||
} |
@@ -8,7 +8,7 @@ using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
namespace Tensorflow.Keras.UnitTest | |||||
namespace TensorFlowNET.Keras.UnitTest | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class OutputTest | public class OutputTest | ||||