Browse Source

KerasObjectLoader

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
954462f4f8
9 changed files with 222 additions and 6 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +1
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  4. +0
    -5
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  5. +18
    -0
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  6. +23
    -0
      src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
  7. +161
    -0
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  8. +15
    -0
      src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs
  9. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/OutputTest.cs

+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow.Keras
{
string Name { get; }
bool Trainable { get; }
bool Built { get; }
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }


+ 2
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -75,6 +75,8 @@ namespace Tensorflow
public TensorShape BatchInputShape => throw new NotImplementedException();

public TF_DataType DType => throw new NotImplementedException();
protected bool built = false;
public bool Built => built;

public RnnCell(bool trainable = true,
string name = null,


+ 1
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -44,6 +44,7 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights.
/// </summary>
protected bool built;
public bool Built => built;
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;



+ 0
- 5
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -30,18 +30,13 @@ namespace Tensorflow.Keras.Engine
public class Sequential : Functional
{
SequentialArgs args;
bool _is_graph_network;
Tensors inputs;
Tensors outputs;

bool _compute_output_and_mask_jointly;
bool _auto_track_sub_layers;
TensorShape _inferred_input_shape;
bool _has_explicit_input_shape;
TF_DataType _input_dtype;
public TensorShape output_shape => outputs[0].TensorShape;
bool built = false;

public Sequential(SequentialArgs args)
: base(args.Inputs, args.Outputs, name: args.Name)


+ 18
- 0
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -1,8 +1,10 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;

namespace Tensorflow.Keras.Models
{
@@ -10,5 +12,21 @@ namespace Tensorflow.Keras.Models
{
public Functional from_config(ModelConfig config)
=> Functional.from_config(config);

public void load_model(string filepath, bool compile = true)
{
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb"));
var saved_mode = SavedModel.Parser.ParseFrom(bytes);
var meta_graph_def = saved_mode.MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;

bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb"));
var metadata = SavedMetadata.Parser.ParseFrom(bytes);

// Recreate layers and metrics using the info stored in the metadata.
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);
}
}
}

+ 23
- 0
src/TensorFlowNET.Keras/Saving/KerasMetaData.cs View File

@@ -0,0 +1,23 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
public class KerasMetaData
{
public string Name { get; set; }
[JsonProperty("class_name")]
public string ClassName { get; set; }
[JsonProperty("is_graph_network")]
public bool IsGraphNetwork { get; set; }
[JsonProperty("shared_object_id")]
public int SharedObjectId { get; set; }
[JsonProperty("must_restore_from_config")]
public bool MustRestoreFromConfig { get; set; }
public ModelConfig Config { get; set; }
[JsonProperty("build_input_shape")]
public TensorShapeConfig BuildInputShape { get; set; }
}
}

+ 161
- 0
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -0,0 +1,161 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Saving
{
public class KerasObjectLoader
{
SavedMetadata _metadata;
SavedObjectGraph _proto;
Dictionary<int, string> _node_paths = new Dictionary<int, string>();
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>();
List<int> _traversed_nodes_from_config = new List<int>();

public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
{
_metadata = metadata;
_proto = object_graph_def;
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath);
}

/// <summary>
/// Load all layer nodes from the metadata.
/// </summary>
/// <param name="compile"></param>
public void load_layers(bool compile = true)
{
var metric_list = new List<ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject>();
foreach (var node_metadata in _metadata.Nodes)
{
if (node_metadata.Identifier == "_tf_keras_metric")
{
metric_list.Add(node_metadata);
continue;
}

_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
}
}

void _load_layer(int node_id, string identifier, string metadata_json)
{
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_revive_from_config(identifier, metadata, node_id);
}

/// <summary>
/// Revives a layer/model from config, or returns None.
/// </summary>
/// <param name="identifier"></param>
/// <param name="metadata"></param>
/// <param name="node_id"></param>
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
{
var obj = _revive_graph_network(identifier, metadata, node_id);
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
}

Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
{
var config = metadata.Config;
var class_name = metadata.ClassName;
Model model = null;
if (class_name == "Sequential")
{
model = new Sequential(new SequentialArgs
{
Name = config.Name
});
}
else if (class_name == "Functional")
{
throw new NotImplementedException("");
}

if (!metadata.IsGraphNetwork)
return null;

// Record this model and its layers. This will later be used to reconstruct
// the model.
var layers = _get_child_layer_node_ids(node_id);
model_layer_dependencies[node_id] = (model, layers);
return model;
}

Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
{
var config = metadata.Config;
var class_name = metadata.ClassName;
var shared_object_id = metadata.SharedObjectId;
var must_restore_from_config = metadata.MustRestoreFromConfig;

return null;
}

/// <summary>
/// Returns the node ids of each layer in a Sequential/Functional model.
/// </summary>
/// <param name="node_id"></param>
int[] _get_child_layer_node_ids(int node_id)
{
int num_layers = 0;
Dictionary<int, int> child_layers = new Dictionary<int, int>();
foreach (var child in _proto.Nodes[node_id].Children)
{
var m = Regex.Match(child.LocalName, @"layer-(\d+)");
if (!m.Success)
continue;
var layer_n = int.Parse(m.Groups[1].Value);
num_layers = max(layer_n + 1, num_layers);
child_layers[layer_n] = child.NodeId;
}

var ordered = new List<int>();
foreach (var n in range(num_layers))
{
if (child_layers.ContainsKey(n))
ordered.Add(child_layers[n]);
else
break;
}
return ordered.ToArray();
}

/// <summary>
/// Recursively records objects recreated from config.
/// </summary>
/// <param name="obj"></param>
/// <param name="proto"></param>
/// <param name="node_id"></param>
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id)
{
if (_traversed_nodes_from_config.Contains(node_id))
return;
var parent_path = _node_paths[node_id];
_traversed_nodes_from_config.Add(node_id);
if (!obj.Built)
{
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_try_build_layer(obj, node_id, metadata.BuildInputShape);
}
}

bool _try_build_layer(Model obj, int node_id, TensorShape build_input_shape)
{
if (obj.Built)
return true;

return false;
}
}
}

+ 15
- 0
src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
public class TensorShapeConfig
{
public string ClassName { get; set; }
public int?[] Items { get; set; }

public static implicit operator TensorShape(TensorShapeConfig shape)
=> new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
}
}

+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/OutputTest.cs View File

@@ -8,7 +8,7 @@ using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;

namespace Tensorflow.Keras.UnitTest
namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class OutputTest


Loading…
Cancel
Save