@@ -58,6 +58,12 @@ namespace Tensorflow | |||
public static void append<T>(this IList<T> list, T element) | |||
=> list.Insert(list.Count, element); | |||
public static void append<T>(this IList<T> list, IList<T> elements) | |||
{ | |||
for (int i = 0; i < elements.Count(); i++) | |||
list.Insert(list.Count, elements[i]); | |||
} | |||
public static T[] concat<T>(this IList<T> list1, IList<T> list2) | |||
{ | |||
var list = new List<T>(); | |||
@@ -38,7 +38,7 @@ namespace Tensorflow.Eager | |||
}*/ | |||
} | |||
Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); | |||
// Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}"); | |||
if (!should_record) return should_record; | |||
Tensor[] op_outputs; | |||
@@ -761,13 +761,6 @@ namespace Tensorflow.Gradients | |||
{ | |||
sx = array_ops.shape(x); | |||
sy = array_ops.shape(y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
return new[] | |||
{ | |||
(sx, rx, true), | |||
(sy, ry, true) | |||
}; | |||
} | |||
else | |||
{ | |||
@@ -775,7 +768,12 @@ namespace Tensorflow.Gradients | |||
sy = array_ops.shape_internal(y, optimize: false); | |||
} | |||
throw new NotImplementedException(""); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
return new[] | |||
{ | |||
(sx, rx, true), | |||
(sy, ry, true) | |||
}; | |||
} | |||
} | |||
} |
@@ -13,6 +13,6 @@ namespace Tensorflow.Keras.Engine | |||
INode[] ParentNodes { get; } | |||
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound(); | |||
bool is_input { get; } | |||
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | |||
List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map); | |||
} | |||
} |
@@ -8,7 +8,9 @@ | |||
ILayer layer; | |||
public ILayer Layer => layer; | |||
int node_index; | |||
public int NodeIndex => node_index; | |||
int tensor_index; | |||
public int TensorIndex => tensor_index; | |||
Tensor tensor; | |||
public KerasHistory(ILayer layer, int node_index, int tensor_index, Tensor tensor) | |||
@@ -11,6 +11,6 @@ namespace Tensorflow.Keras.Saving | |||
public string Name { get; set; } | |||
public string ClassName { get; set; } | |||
public LayerArgs Config { get; set; } | |||
public List<INode> InboundNodes { get; set; } | |||
public List<NodeConfig> InboundNodes { get; set; } | |||
} | |||
} |
@@ -9,7 +9,10 @@ namespace Tensorflow.Keras.Saving | |||
{ | |||
public string Name { get; set; } | |||
public List<LayerConfig> Layers { get; set; } | |||
public List<ILayer> InputLayers { get; set; } | |||
public List<ILayer> OutputLayers { get; set; } | |||
public List<NodeConfig> InputLayers { get; set; } | |||
public List<NodeConfig> OutputLayers { get; set; } | |||
public override string ToString() | |||
=> $"{Name}, {Layers.Count} Layers, {InputLayers.Count} Input Layers, {OutputLayers.Count} Output Layers"; | |||
} | |||
} |
@@ -9,5 +9,8 @@ namespace Tensorflow.Keras.Saving | |||
public string Name { get; set; } | |||
public int NodeIndex { get; set; } | |||
public int TensorIndex { get; set; } | |||
public override string ToString() | |||
=> $"{Name}, {NodeIndex}, {TensorIndex}"; | |||
} | |||
} |
@@ -0,0 +1,23 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Functional | |||
{ | |||
/// <summary> | |||
/// Adds layers that are not connected to the outputs to the model. | |||
/// </summary> | |||
/// <param name="created_layers"></param> | |||
public void connect_ancillary_layers(Dictionary<string, ILayer> created_layers) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,140 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Functional | |||
{ | |||
public static Functional from_config(ModelConfig config) | |||
{ | |||
var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); | |||
var model = new Functional(input_tensors, output_tensors, name: config.Name); | |||
model.connect_ancillary_layers(created_layers); | |||
return model; | |||
} | |||
/// <summary> | |||
/// Reconstructs graph from config object. | |||
/// </summary> | |||
/// <param name="config"></param> | |||
/// <returns></returns> | |||
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
{ | |||
// Layer instances created during the graph reconstruction process. | |||
var created_layers = new Dictionary<string, ILayer>(); | |||
var node_index_map = new Dictionary<(string, int), int>(); | |||
var node_count_by_layer = new Dictionary<ILayer, int>(); | |||
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | |||
// First, we create all layers and enqueue nodes to be processed | |||
foreach (var layer_data in config.Layers) | |||
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); | |||
// Then we process nodes in order of layer depth. | |||
// Nodes that cannot yet be processed (if the inbound node | |||
// does not yet exist) are re-enqueued, and the process | |||
// is repeated until all nodes are processed. | |||
while (unprocessed_nodes.Count > 0) | |||
{ | |||
foreach(var layer_data in config.Layers) | |||
{ | |||
var layer = created_layers[layer_data.Name]; | |||
if (unprocessed_nodes.ContainsKey(layer)) | |||
{ | |||
var node_data = unprocessed_nodes[layer]; | |||
// foreach (var node_data in unprocessed_nodes[layer]) | |||
{ | |||
process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map); | |||
unprocessed_nodes.Remove(layer); | |||
} | |||
} | |||
} | |||
} | |||
var input_tensors = new List<Tensor>(); | |||
foreach (var layer_data in config.InputLayers) | |||
{ | |||
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); | |||
var layer = created_layers[layer_name]; | |||
var layer_output_tensors = layer.InboundNodes[node_index].Outputs; | |||
input_tensors.append(layer_output_tensors[tensor_index]); | |||
} | |||
var output_tensors = new List<Tensor>(); | |||
foreach (var layer_data in config.OutputLayers) | |||
{ | |||
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex); | |||
var layer = created_layers[layer_name]; | |||
var layer_output_tensors = layer.InboundNodes[node_index].Outputs; | |||
output_tensors.append(layer_output_tensors[tensor_index]); | |||
} | |||
return (input_tensors, output_tensors, created_layers); | |||
} | |||
static void process_layer(Dictionary<string, ILayer> created_layers, | |||
LayerConfig layer_data, | |||
Dictionary<ILayer, NodeConfig> unprocessed_nodes, | |||
Dictionary<ILayer, int> node_count_by_layer) | |||
{ | |||
ILayer layer = null; | |||
var layer_name = layer_data.Name; | |||
if (created_layers.ContainsKey(layer_name)) | |||
layer = created_layers[layer_name]; | |||
else | |||
{ | |||
layer = layer_data.ClassName switch | |||
{ | |||
"InputLayer" => InputLayer.from_config(layer_data.Config), | |||
"Dense" => Dense.from_config(layer_data.Config), | |||
_ => throw new NotImplementedException("") | |||
}; | |||
created_layers[layer_name] = layer; | |||
} | |||
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0; | |||
var inbound_nodes_data = layer_data.InboundNodes; | |||
foreach (var node_data in inbound_nodes_data) | |||
{ | |||
if (!unprocessed_nodes.ContainsKey(layer)) | |||
unprocessed_nodes[layer] = node_data; | |||
else | |||
unprocessed_nodes.Add(layer, node_data); | |||
} | |||
} | |||
static void process_node(ILayer layer, | |||
NodeConfig node_data, | |||
Dictionary<string, ILayer> created_layers, | |||
Dictionary<ILayer, int> node_count_by_layer, | |||
Dictionary<(string, int), int> node_index_map) | |||
{ | |||
var input_tensors = new List<Tensor>(); | |||
var inbound_layer_name = node_data.Name; | |||
var inbound_node_index = node_data.NodeIndex; | |||
var inbound_tensor_index = node_data.TensorIndex; | |||
var inbound_layer = created_layers[inbound_layer_name]; | |||
var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; | |||
input_tensors.Add(inbound_node.Outputs[inbound_node_index]); | |||
var output_tensors = layer.Apply(input_tensors); | |||
// Update node index map. | |||
var output_index = output_tensors[0].KerasHistory.NodeIndex; | |||
node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index; | |||
node_count_by_layer[layer] += 1; | |||
} | |||
static bool _should_skip_first_node(ILayer layer) | |||
{ | |||
return layer is Functional && layer.Layers[0] is InputLayer; | |||
} | |||
} | |||
} |
@@ -44,14 +44,14 @@ namespace Tensorflow.Keras.Engine | |||
var layer_configs = new List<LayerConfig>(); | |||
foreach (var layer in _layers) | |||
{ | |||
var filtered_inbound_nodes = new List<INode>(); | |||
var filtered_inbound_nodes = new List<NodeConfig>(); | |||
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||
{ | |||
var node_key = _make_node_key(layer.Name, original_node_index); | |||
if (NetworkNodes.Contains(node_key) && !node.is_input) | |||
{ | |||
var node_data = node.serialize(_make_node_key, node_conversion_map); | |||
throw new NotImplementedException(""); | |||
filtered_inbound_nodes.append(node_data); | |||
} | |||
} | |||
@@ -62,12 +62,42 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
config.Layers = layer_configs; | |||
return config; | |||
} | |||
// Gather info about inputs and outputs. | |||
var model_inputs = new List<NodeConfig>(); | |||
foreach (var i in range(_input_layers.Count)) | |||
{ | |||
var (layer, node_index, tensor_index) = _input_coordinates[i]; | |||
var node_key = _make_node_key(layer.Name, node_index); | |||
if (!NetworkNodes.Contains(node_key)) | |||
continue; | |||
var new_node_index = node_conversion_map[node_key]; | |||
model_inputs.append(new NodeConfig | |||
{ | |||
Name = layer.Name, | |||
NodeIndex = new_node_index, | |||
TensorIndex = tensor_index | |||
}); | |||
} | |||
config.InputLayers = model_inputs; | |||
bool _should_skip_first_node(ILayer layer) | |||
{ | |||
return layer is Functional && layer.Layers[0] is InputLayer; | |||
var model_outputs = new List<NodeConfig>(); | |||
foreach (var i in range(_output_layers.Count)) | |||
{ | |||
var (layer, node_index, tensor_index) = _output_coordinates[i]; | |||
var node_key = _make_node_key(layer.Name, node_index); | |||
if (!NetworkNodes.Contains(node_key)) | |||
continue; | |||
var new_node_index = node_conversion_map[node_key]; | |||
model_outputs.append(new NodeConfig | |||
{ | |||
Name = layer.Name, | |||
NodeIndex = new_node_index, | |||
TensorIndex = tensor_index | |||
}); | |||
} | |||
config.OutputLayers = model_outputs; | |||
return config; | |||
} | |||
string _make_node_key(string layer_name, int node_index) | |||
@@ -244,6 +244,6 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
public virtual LayerArgs get_config() | |||
=> throw new NotImplementedException(""); | |||
=> args; | |||
} | |||
} |
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.Saving; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -10,9 +12,19 @@ namespace Tensorflow.Keras.Engine | |||
/// Serializes `Node` for Functional API's `get_config`. | |||
/// </summary> | |||
/// <returns></returns> | |||
public NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map) | |||
public List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map) | |||
{ | |||
throw new NotImplementedException(""); | |||
return KerasInputs.Select(x => { | |||
var kh = x.KerasHistory; | |||
var node_key = make_node_key(kh.Layer.Name, kh.NodeIndex); | |||
var new_node_index = node_conversion_map.Get(node_key, 0); | |||
return new NodeConfig | |||
{ | |||
Name = kh.Layer.Name, | |||
NodeIndex = new_node_index, | |||
TensorIndex = kh.TensorIndex | |||
}; | |||
}).ToList(); | |||
} | |||
} | |||
} |
@@ -5,7 +5,9 @@ using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Models; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras | |||
{ | |||
@@ -21,6 +23,7 @@ namespace Tensorflow.Keras | |||
public BackendImpl backend { get; } = new BackendImpl(); | |||
public OptimizerApi optimizers { get; } = new OptimizerApi(); | |||
public MetricsApi metrics { get; } = new MetricsApi(); | |||
public ModelsApi models { get; } = new ModelsApi(); | |||
public Sequential Sequential(List<ILayer> layers = null, | |||
string name = null) | |||
@@ -84,5 +84,10 @@ namespace Tensorflow.Keras.Layers | |||
return outputs; | |||
} | |||
public static Dense from_config(LayerArgs args) | |||
{ | |||
return new Dense(args as DenseArgs); | |||
} | |||
} | |||
} |
@@ -101,7 +101,9 @@ namespace Tensorflow.Keras.Layers | |||
tf.Context.restore_mode(); | |||
} | |||
public override LayerArgs get_config() | |||
=> args; | |||
public static InputLayer from_config(LayerArgs args) | |||
{ | |||
return new InputLayer(args as InputLayerArgs); | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Saving; | |||
namespace Tensorflow.Keras.Models | |||
{ | |||
public class ModelsApi | |||
{ | |||
public Functional from_config(ModelConfig config) | |||
=> Functional.from_config(config); | |||
} | |||
} |
@@ -10,11 +10,13 @@ namespace TensorFlowNET.UnitTest.Keras | |||
[TestClass] | |||
public class ModelSaveTest : EagerModeTestBase | |||
{ | |||
[TestMethod, Ignore] | |||
[TestMethod] | |||
public void GetAndFromConfig() | |||
{ | |||
var model = GetFunctionalModel(); | |||
var config = model.get_config(); | |||
var new_model = keras.models.from_config(config); | |||
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | |||
} | |||
Functional GetFunctionalModel() | |||