diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 18dc4426..eaa78399 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -532,5 +532,13 @@ namespace Tensorflow
var g = get_default_graph();
return g.get_name_scope();
}
+
+ public static bool executing_eagerly_outside_functions()
+ {
+ if (tf.Context.executing_eagerly())
+ return true;
+ else
+ throw new NotImplementedException("");
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index ee3eaead..b31a523a 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -142,6 +142,19 @@ namespace Tensorflow.Keras
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
}
+ public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
+ {
+ if (ops.executing_eagerly_outside_functions())
+ {
+ foreach (var (x, value) in tuples)
+ x.assign(value);
+ }
+ else
+ {
+ throw new NotImplementedException("");
+ }
+ }
+
/// <summary>
/// Pads the 2nd and 3rd dimensions of a 4D tensor.
/// </summary>
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
index 4be5cc0d..2ba215e8 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
@@ -9,48 +9,28 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
- private long fileId = -1;
- private long f = -1;
- public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null)
+ public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
- long root = Hdf5.OpenFile(filepath, true);
+ long fileId = Hdf5.OpenFile(filepath, true);
- long fileId = root;
- //try
- //{
+ bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
+ bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
- bool msuccess = Hdf5.GroupExists(fileId, "model_weights");
- bool lsuccess = Hdf5.GroupExists(fileId, "layer_names");
-
- if (!lsuccess && msuccess)
- {
- f = H5G.open(fileId, "model_weights");
-
- }
- if (by_name)
- {
- //fdf5_format.load_weights_from_hdf5_group_by_name();
- }
- else
- {
- fdf5_format.load_weights_from_hdf5_group(f, this);
- }
- H5G.close(f);
- //}
- //catch (Exception ex)
- //{
- // if (fileId != -1)
- // {
- // Hdf5.CloseFile(fileId);
- // }
- // if (f != -1)
- // {
- // H5G.close(f);
- // }
- // throw new Exception(ex.ToString());
- //}
+ if (!lsuccess && msuccess)
+ {
+ fileId = H5G.open(fileId, "model_weights");
+ }
+ if (by_name)
+ {
+ //fdf5_format.load_weights_from_hdf5_group_by_name();
+ throw new NotImplementedException("");
+ }
+ else
+ {
+ fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+ }
+ H5G.close(fileId);
}
-
}
}
diff --git a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs
index 3a9d2438..a2a9e537 100644
--- a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs
+++ b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs
@@ -7,6 +7,8 @@ using Tensorflow.Keras.Engine;
using HDF5CSharp;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
+using System.Linq;
+
namespace Tensorflow.Keras.Saving
{
public class fdf5_format
@@ -45,13 +47,29 @@ namespace Tensorflow.Keras.Saving
{
}
- public static void preprocess_weights_for_loading(long filepath = -1, Dictionary custom_objects = null, bool compile = false)
- {
- }
- public static void _convert_rnn_weights(long filepath = -1, Dictionary custom_objects = null, bool compile = false)
+ /// <summary>
+ /// Preprocess layer weights between different Keras formats.
+ /// </summary>
+ /// <param name="layer">Layer whose weights are being loaded.</param>
+ /// <param name="weights">Raw weight arrays read from the HDF5 file.</param>
+ /// <param name="original_keras_version">Keras version that wrote the file.</param>
+ /// <param name="original_backend">Backend that wrote the file.</param>
+ public static List<NDArray> preprocess_weights_for_loading(ILayer layer, List<NDArray> weights, string original_keras_version = null, string original_backend = null)
{
+ // convert CuDNN layers
+ return _convert_rnn_weights(layer, weights);
+ }
+ /// <summary>
+ /// Converts weights for RNN layers between native and CuDNN format.
+ /// </summary>
+ /// <param name="layer">RNN layer being loaded.</param>
+ /// <param name="weights">Weight arrays to convert.</param>
+ static List<NDArray> _convert_rnn_weights(ILayer layer, List<NDArray> weights)
+ {
+ var target_class = layer.GetType().Name;
+ return weights;
}
public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false)
{
@@ -65,56 +83,79 @@ namespace Tensorflow.Keras.Saving
{
}
- public static void load_weights_from_hdf5_group(long f=-1,Model model=null)
+ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
- string original_keras_version = "1";
+ string original_keras_version = "2.4.0";
string original_backend = null;
if (Hdf5.AttributeExists(f, "keras_version"))
{
- (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
+ var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
if (success)
- {
- original_keras_version = attr[0];
- }
+ original_keras_version = attr.First();
+ // keras version should be 2.5.0+
+ var ver_major = int.Parse(original_keras_version.Split('.')[0]);
+ var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
+ if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
+ throw new ValueError("keras version should be 2.5.0 or later.");
}
if (Hdf5.AttributeExists(f, "backend"))
{
- (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", "");
+ var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
if (success)
- {
- original_backend = attr[0];
- }
+ original_backend = attr.First();
}
List<ILayer> filtered_layers = new List<ILayer>();
- List weights;
- foreach (var layer in model.Layers)
+ List<IVariableV1> weights;
+ foreach (var layer in layers)
{
weights = _legacy_weights(layer);
- if (weights.Count>0)
+ if (weights.Count > 0)
{
filtered_layers.append(layer);
}
}
- string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names");
- List weight_values=new List();
- foreach (var i in filtered_layers) {
- long g = H5G.open(f, i.Name);
- string[] weight_names = null;
- if (g != -1)
- {
- weight_names = load_attributes_from_hdf5_group(g, "weight_names");
- }
- if (weight_names != null)
+ string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
+ var filtered_layer_names = new List<string>();
+ foreach(var name in layer_names)
+ {
+ long g = H5G.open(f, name);
+ var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+ if (weight_names.Count() > 0)
+ filtered_layer_names.Add(name);
+ H5G.close(g);
+ }
+ layer_names = filtered_layer_names.ToArray();
+ if (layer_names.Length != filtered_layers.Count())
+ throw new ValueError("You are trying to load a weight file " +
+ $"containing {layer_names}" +
+ $" layers into a model with {filtered_layers.Count} layers.");
+
+ var weight_value_tuples = new List<(IVariableV1, NDArray)>();
+ foreach (var (k, name) in enumerate(layer_names))
+ {
+ var weight_values = new List<NDArray>();
+ long g = H5G.open(f, name);
+ var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+ foreach (var i_ in weight_names)
{
- foreach (var i_ in weight_names) {
- (bool success, Array result) = Hdf5.ReadDataset(g, i_);
- //
+ (bool success, Array result) = Hdf5.ReadDataset(g, i_);
+ if (success)
weight_values.Add(np.array(result));
- }
}
H5G.close(g);
+ var layer = filtered_layers[k];
+ var symbolic_weights = _legacy_weights(layer);
+ preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
+ if (weight_values.Count() != symbolic_weights.Count())
+ throw new ValueError($"Layer #{k} (named {layer.Name}" +
+ "in the current model) was found to " +
+ $"correspond to layer {name} in the save file." +
+ $"However the new layer {layer.Name} expects " +
+ $"{symbolic_weights.Count()} weights, but the saved weights have " +
+ $"{weight_values.Count()} elements.");
+ weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
}
-
+ keras.backend.batch_set_value(weight_value_tuples);
}
public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false)
{
@@ -128,15 +169,13 @@ namespace Tensorflow.Keras.Saving
{
}
- public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "")
+ public static string[] load_attributes_from_hdf5_group(long group, string name)
{
- if (Hdf5.AttributeExists(f, name))
+ if (Hdf5.AttributeExists(group, name))
{
- (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, "");
+ var (success, attr) = Hdf5.ReadStringAttributes(group, name, "");
if (success)
- {
- return attr;
- }
+ return attr.ToArray();
}
return null;
}
@@ -145,33 +184,10 @@ namespace Tensorflow.Keras.Saving
}
- public static List<Tensor> _legacy_weights(ILayer layer)
+ public static List<IVariableV1> _legacy_weights(ILayer layer)
{
-
- List weights= new List();
- if (layer.trainable_weights.Count != 0)
- {
- Tensor[] trainable_weights = Array.ConvertAll(layer.trainable_weights.ToArray(), s => s.AsTensor());
- Tensor[] non_trainable_weights =null;
- if (layer.non_trainable_weights.Count != 0)
- {
- non_trainable_weights = Array.ConvertAll(layer.non_trainable_weights.ToArray(), s => s.AsTensor());
- }
- foreach (var i in trainable_weights) {
- if (non_trainable_weights != null)
- {
- foreach (var i_ in non_trainable_weights)
- {
- weights.Add(i + i_);
- }
- }
- else {
- weights.Add(i);
- };
-
-
- }
- }
+ var weights = layer.trainable_weights.Select(x => x).ToList();
+ weights.AddRange(layer.non_trainable_weights);
return weights;
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index dbbfc19e..e4864e1d 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -46,7 +46,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
-