|
|
@@ -0,0 +1,161 @@ |
|
|
|
using Newtonsoft.Json; |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text.RegularExpressions; |
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
using Tensorflow.Keras.Engine; |
|
|
|
using ThirdParty.Tensorflow.Python.Keras.Protobuf; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Saving |
|
|
|
{ |
|
|
|
public class KerasObjectLoader |
|
|
|
{ |
|
|
|
SavedMetadata _metadata; |
|
|
|
SavedObjectGraph _proto; |
|
|
|
Dictionary<int, string> _node_paths = new Dictionary<int, string>(); |
|
|
|
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); |
|
|
|
List<int> _traversed_nodes_from_config = new List<int>(); |
|
|
|
|
|
|
|
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) |
|
|
|
{ |
|
|
|
_metadata = metadata; |
|
|
|
_proto = object_graph_def; |
|
|
|
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Load all layer nodes from the metadata. |
|
|
|
/// </summary> |
|
|
|
/// <param name="compile"></param> |
|
|
|
public void load_layers(bool compile = true) |
|
|
|
{ |
|
|
|
var metric_list = new List<ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject>(); |
|
|
|
foreach (var node_metadata in _metadata.Nodes) |
|
|
|
{ |
|
|
|
if (node_metadata.Identifier == "_tf_keras_metric") |
|
|
|
{ |
|
|
|
metric_list.Add(node_metadata); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void _load_layer(int node_id, string identifier, string metadata_json) |
|
|
|
{ |
|
|
|
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); |
|
|
|
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); |
|
|
|
_revive_from_config(identifier, metadata, node_id); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Revives a layer/model from config, or returns None. |
|
|
|
/// </summary> |
|
|
|
/// <param name="identifier"></param> |
|
|
|
/// <param name="metadata"></param> |
|
|
|
/// <param name="node_id"></param> |
|
|
|
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) |
|
|
|
{ |
|
|
|
var obj = _revive_graph_network(identifier, metadata, node_id); |
|
|
|
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); |
|
|
|
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); |
|
|
|
} |
|
|
|
|
|
|
|
Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) |
|
|
|
{ |
|
|
|
var config = metadata.Config; |
|
|
|
var class_name = metadata.ClassName; |
|
|
|
Model model = null; |
|
|
|
if (class_name == "Sequential") |
|
|
|
{ |
|
|
|
model = new Sequential(new SequentialArgs |
|
|
|
{ |
|
|
|
Name = config.Name |
|
|
|
}); |
|
|
|
} |
|
|
|
else if (class_name == "Functional") |
|
|
|
{ |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
if (!metadata.IsGraphNetwork) |
|
|
|
return null; |
|
|
|
|
|
|
|
// Record this model and its layers. This will later be used to reconstruct |
|
|
|
// the model. |
|
|
|
var layers = _get_child_layer_node_ids(node_id); |
|
|
|
model_layer_dependencies[node_id] = (model, layers); |
|
|
|
return model; |
|
|
|
} |
|
|
|
|
|
|
|
Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) |
|
|
|
{ |
|
|
|
var config = metadata.Config; |
|
|
|
var class_name = metadata.ClassName; |
|
|
|
var shared_object_id = metadata.SharedObjectId; |
|
|
|
var must_restore_from_config = metadata.MustRestoreFromConfig; |
|
|
|
|
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Returns the node ids of each layer in a Sequential/Functional model. |
|
|
|
/// </summary> |
|
|
|
/// <param name="node_id"></param> |
|
|
|
int[] _get_child_layer_node_ids(int node_id) |
|
|
|
{ |
|
|
|
int num_layers = 0; |
|
|
|
Dictionary<int, int> child_layers = new Dictionary<int, int>(); |
|
|
|
foreach (var child in _proto.Nodes[node_id].Children) |
|
|
|
{ |
|
|
|
var m = Regex.Match(child.LocalName, @"layer-(\d+)"); |
|
|
|
if (!m.Success) |
|
|
|
continue; |
|
|
|
var layer_n = int.Parse(m.Groups[1].Value); |
|
|
|
num_layers = max(layer_n + 1, num_layers); |
|
|
|
child_layers[layer_n] = child.NodeId; |
|
|
|
} |
|
|
|
|
|
|
|
var ordered = new List<int>(); |
|
|
|
foreach (var n in range(num_layers)) |
|
|
|
{ |
|
|
|
if (child_layers.ContainsKey(n)) |
|
|
|
ordered.Add(child_layers[n]); |
|
|
|
else |
|
|
|
break; |
|
|
|
} |
|
|
|
return ordered.ToArray(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Recursively records objects recreated from config. |
|
|
|
/// </summary> |
|
|
|
/// <param name="obj"></param> |
|
|
|
/// <param name="proto"></param> |
|
|
|
/// <param name="node_id"></param> |
|
|
|
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) |
|
|
|
{ |
|
|
|
if (_traversed_nodes_from_config.Contains(node_id)) |
|
|
|
return; |
|
|
|
var parent_path = _node_paths[node_id]; |
|
|
|
_traversed_nodes_from_config.Add(node_id); |
|
|
|
if (!obj.Built) |
|
|
|
{ |
|
|
|
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); |
|
|
|
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); |
|
|
|
_try_build_layer(obj, node_id, metadata.BuildInputShape); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool _try_build_layer(Model obj, int node_id, TensorShape build_input_shape) |
|
|
|
{ |
|
|
|
if (obj.Built) |
|
|
|
return true; |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |