Browse Source

Add save and restore model from config.

tags/v0.30
Oceania2018 4 years ago
parent
commit
8b9fca47e9
18 changed files with 269 additions and 26 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  3. +6
    -8
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/INode.cs
  5. +2
    -0
      src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
  7. +5
    -2
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  8. +3
    -0
      src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
  9. +23
    -0
      src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs
  10. +140
    -0
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  11. +37
    -7
      src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
  12. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  13. +14
    -2
      src/TensorFlowNET.Keras/Engine/Node.Serialize.cs
  14. +3
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  15. +5
    -0
      src/TensorFlowNET.Keras/Layers/Dense.cs
  16. +4
    -2
      src/TensorFlowNET.Keras/Layers/InputLayer.cs
  17. +14
    -0
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  18. +3
    -1
      test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs

+ 6
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -58,6 +58,12 @@ namespace Tensorflow
public static void append<T>(this IList<T> list, T element)
=> list.Insert(list.Count, element);

public static void append<T>(this IList<T> list, IList<T> elements)
{
for (int i = 0; i < elements.Count(); i++)
list.Insert(list.Count, elements[i]);
}

public static T[] concat<T>(this IList<T> list1, IList<T> list2)
{
var list = new List<T>();


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow.Eager
}*/
}

Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
// Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
if (!should_record) return should_record;

Tensor[] op_outputs;


+ 6
- 8
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -761,13 +761,6 @@ namespace Tensorflow.Gradients
{
sx = array_ops.shape(x);
sy = array_ops.shape(y);

var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
return new[]
{
(sx, rx, true),
(sy, ry, true)
};
}
else
{
@@ -775,7 +768,12 @@ namespace Tensorflow.Gradients
sy = array_ops.shape_internal(y, optimize: false);
}

throw new NotImplementedException("");
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
return new[]
{
(sx, rx, true),
(sy, ry, true)
};
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/INode.cs View File

@@ -13,6 +13,6 @@ namespace Tensorflow.Keras.Engine
INode[] ParentNodes { get; }
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
bool is_input { get; }
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
}
}

+ 2
- 0
src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs View File

@@ -8,7 +8,9 @@
ILayer layer;
public ILayer Layer => layer;
int node_index;
public int NodeIndex => node_index;
int tensor_index;
public int TensorIndex => tensor_index;
Tensor tensor;

public KerasHistory(ILayer layer, int node_index, int tensor_index, Tensor tensor)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs View File

@@ -11,6 +11,6 @@ namespace Tensorflow.Keras.Saving
public string Name { get; set; }
public string ClassName { get; set; }
public LayerArgs Config { get; set; }
public List<INode> InboundNodes { get; set; }
public List<NodeConfig> InboundNodes { get; set; }
}
}

+ 5
- 2
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -9,7 +9,10 @@ namespace Tensorflow.Keras.Saving
{
public string Name { get; set; }
public List<LayerConfig> Layers { get; set; }
public List<ILayer> InputLayers { get; set; }
public List<ILayer> OutputLayers { get; set; }
public List<NodeConfig> InputLayers { get; set; }
public List<NodeConfig> OutputLayers { get; set; }

public override string ToString()
=> $"{Name}, {Layers.Count} Layers, {InputLayers.Count} Input Layers, {OutputLayers.Count} Output Layers";
}
}

+ 3
- 0
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -9,5 +9,8 @@ namespace Tensorflow.Keras.Saving
public string Name { get; set; }
public int NodeIndex { get; set; }
public int TensorIndex { get; set; }

public override string ToString()
=> $"{Name}, {NodeIndex}, {TensorIndex}";
}
}

+ 23
- 0
src/TensorFlowNET.Keras/Engine/Functional.ConnectAncillaryLayers.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
/// <summary>
/// Adds layers that are not connected to the outputs to the model.
/// </summary>
/// <param name="created_layers"></param>
public void connect_ancillary_layers(Dictionary<string, ILayer> created_layers)
{
}
}
}

+ 140
- 0
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -0,0 +1,140 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
public static Functional from_config(ModelConfig config)
{
var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
var model = new Functional(input_tensors, output_tensors, name: config.Name);
model.connect_ancillary_layers(created_layers);
return model;
}

/// <summary>
/// Reconstructs graph from config object.
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
{
// Layer instances created during the graph reconstruction process.
var created_layers = new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);

// Then we process nodes in order of layer depth.
// Nodes that cannot yet be processed (if the inbound node
// does not yet exist) are re-enqueued, and the process
// is repeated until all nodes are processed.
while (unprocessed_nodes.Count > 0)
{
foreach(var layer_data in config.Layers)
{
var layer = created_layers[layer_data.Name];
if (unprocessed_nodes.ContainsKey(layer))
{
var node_data = unprocessed_nodes[layer];
// foreach (var node_data in unprocessed_nodes[layer])
{
process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map);
unprocessed_nodes.Remove(layer);
}
}
}
}

var input_tensors = new List<Tensor>();
foreach (var layer_data in config.InputLayers)
{
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex);
var layer = created_layers[layer_name];
var layer_output_tensors = layer.InboundNodes[node_index].Outputs;
input_tensors.append(layer_output_tensors[tensor_index]);
}

var output_tensors = new List<Tensor>();
foreach (var layer_data in config.OutputLayers)
{
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex);
var layer = created_layers[layer_name];
var layer_output_tensors = layer.InboundNodes[node_index].Outputs;
output_tensors.append(layer_output_tensors[tensor_index]);
}

return (input_tensors, output_tensors, created_layers);
}

static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer)
{
ILayer layer = null;
var layer_name = layer_data.Name;
if (created_layers.ContainsKey(layer_name))
layer = created_layers[layer_name];
else
{
layer = layer_data.ClassName switch
{
"InputLayer" => InputLayer.from_config(layer_data.Config),
"Dense" => Dense.from_config(layer_data.Config),
_ => throw new NotImplementedException("")
};

created_layers[layer_name] = layer;
}
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;

var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
unprocessed_nodes[layer] = node_data;
else
unprocessed_nodes.Add(layer, node_data);
}
}

static void process_node(ILayer layer,
NodeConfig node_data,
Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{
var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);

var output_tensors = layer.Apply(input_tensors);

// Update node index map.
var output_index = output_tensors[0].KerasHistory.NodeIndex;
node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index;
node_count_by_layer[layer] += 1;
}

static bool _should_skip_first_node(ILayer layer)
{
return layer is Functional && layer.Layers[0] is InputLayer;
}
}
}

+ 37
- 7
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -44,14 +44,14 @@ namespace Tensorflow.Keras.Engine
var layer_configs = new List<LayerConfig>();
foreach (var layer in _layers)
{
var filtered_inbound_nodes = new List<INode>();
var filtered_inbound_nodes = new List<NodeConfig>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
{
var node_data = node.serialize(_make_node_key, node_conversion_map);
throw new NotImplementedException("");
filtered_inbound_nodes.append(node_data);
}
}

@@ -62,12 +62,42 @@ namespace Tensorflow.Keras.Engine
}
config.Layers = layer_configs;

return config;
}
// Gather info about inputs and outputs.
var model_inputs = new List<NodeConfig>();
foreach (var i in range(_input_layers.Count))
{
var (layer, node_index, tensor_index) = _input_coordinates[i];
var node_key = _make_node_key(layer.Name, node_index);
if (!NetworkNodes.Contains(node_key))
continue;
var new_node_index = node_conversion_map[node_key];
model_inputs.append(new NodeConfig
{
Name = layer.Name,
NodeIndex = new_node_index,
TensorIndex = tensor_index
});
}
config.InputLayers = model_inputs;

bool _should_skip_first_node(ILayer layer)
{
return layer is Functional && layer.Layers[0] is InputLayer;
var model_outputs = new List<NodeConfig>();
foreach (var i in range(_output_layers.Count))
{
var (layer, node_index, tensor_index) = _output_coordinates[i];
var node_key = _make_node_key(layer.Name, node_index);
if (!NetworkNodes.Contains(node_key))
continue;
var new_node_index = node_conversion_map[node_key];
model_outputs.append(new NodeConfig
{
Name = layer.Name,
NodeIndex = new_node_index,
TensorIndex = tensor_index
});
}
config.OutputLayers = model_outputs;

return config;
}

string _make_node_key(string layer_name, int node_index)


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -244,6 +244,6 @@ namespace Tensorflow.Keras.Engine
}

public virtual LayerArgs get_config()
=> throw new NotImplementedException("");
=> args;
}
}

+ 14
- 2
src/TensorFlowNET.Keras/Engine/Node.Serialize.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
@@ -10,9 +12,19 @@ namespace Tensorflow.Keras.Engine
/// Serializes `Node` for Functional API's `get_config`.
/// </summary>
/// <returns></returns>
public NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map)
public List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map)
{
throw new NotImplementedException("");
return KerasInputs.Select(x => {
var kh = x.KerasHistory;
var node_key = make_node_key(kh.Layer.Name, kh.NodeIndex);
var new_node_index = node_conversion_map.Get(node_key, 0);
return new NodeConfig
{
Name = kh.Layer.Name,
NodeIndex = new_node_index,
TensorIndex = kh.TensorIndex
};
}).ToList();
}
}
}

+ 3
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -5,7 +5,9 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras
{
@@ -21,6 +23,7 @@ namespace Tensorflow.Keras
public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi();

public Sequential Sequential(List<ILayer> layers = null,
string name = null)


+ 5
- 0
src/TensorFlowNET.Keras/Layers/Dense.cs View File

@@ -84,5 +84,10 @@ namespace Tensorflow.Keras.Layers

return outputs;
}

public static Dense from_config(LayerArgs args)
{
return new Dense(args as DenseArgs);
}
}
}

+ 4
- 2
src/TensorFlowNET.Keras/Layers/InputLayer.cs View File

@@ -101,7 +101,9 @@ namespace Tensorflow.Keras.Layers
tf.Context.restore_mode();
}

public override LayerArgs get_config()
=> args;
public static InputLayer from_config(LayerArgs args)
{
return new InputLayer(args as InputLayerArgs);
}
}
}

+ 14
- 0
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Models
{
public class ModelsApi
{
public Functional from_config(ModelConfig config)
=> Functional.from_config(config);
}
}

+ 3
- 1
test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs View File

@@ -10,11 +10,13 @@ namespace TensorFlowNET.UnitTest.Keras
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
[TestMethod, Ignore]
[TestMethod]
public void GetAndFromConfig()
{
var model = GetFunctionalModel();
var config = model.get_config();
var new_model = keras.models.from_config(config);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
}

Functional GetFunctionalModel()


Loading…
Cancel
Save