diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 18dc4426..eaa78399 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -532,5 +532,13 @@ namespace Tensorflow var g = get_default_graph(); return g.get_name_scope(); } + + public static bool executing_eagerly_outside_functions() + { + if (tf.Context.executing_eagerly()) + return true; + else + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index ee3eaead..b31a523a 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -142,6 +142,19 @@ namespace Tensorflow.Keras _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); } + public void batch_set_value(List<(IVariableV1, NDArray)> tuples) + { + if (ops.executing_eagerly_outside_functions()) + { + foreach (var (x, value) in tuples) + x.assign(value); + } + else + { + throw new NotImplementedException(""); + } + } + /// /// Pads the 2nd and 3rd dimensions of a 4D tensor. /// diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs index 4be5cc0d..2ba215e8 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -9,48 +9,28 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - private long fileId = -1; - private long f = -1; - public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null) + public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) { - long root = Hdf5.OpenFile(filepath, true); + long fileId = Hdf5.OpenFile(filepath, true); - long fileId = root; - //try - //{ + bool msuccess = Hdf5.GroupExists(fileId, "model_weights"); + bool lsuccess = Hdf5.GroupExists(fileId, "layer_names"); - bool msuccess = Hdf5.GroupExists(fileId, "model_weights"); - bool lsuccess = Hdf5.GroupExists(fileId, "layer_names"); - - if (!lsuccess && msuccess) - { - f = H5G.open(fileId, "model_weights"); - - } - if (by_name) - { - //fdf5_format.load_weights_from_hdf5_group_by_name(); - } - else - { - fdf5_format.load_weights_from_hdf5_group(f, this); - } - H5G.close(f); - //} - //catch (Exception ex) - //{ - // if (fileId != -1) - // { - // Hdf5.CloseFile(fileId); - // } - // if (f != -1) - // { - // H5G.close(f); - // } - // throw new Exception(ex.ToString()); - //} + if (!lsuccess && msuccess) + { + fileId = H5G.open(fileId, "model_weights"); + } + if (by_name) + { + //fdf5_format.load_weights_from_hdf5_group_by_name(); + throw new NotImplementedException(""); + } + else + { + fdf5_format.load_weights_from_hdf5_group(fileId, Layers); + } + H5G.close(fileId); } - } } diff --git a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs index 3a9d2438..a2a9e537 100644 --- a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs @@ -7,6 +7,8 @@ using Tensorflow.Keras.Engine; using HDF5CSharp; using static Tensorflow.Binding; using static Tensorflow.KerasApi; +using System.Linq; + namespace Tensorflow.Keras.Saving { public class fdf5_format @@ -45,13 +47,29 @@ namespace Tensorflow.Keras.Saving { } - public static void preprocess_weights_for_loading(long filepath = -1, Dictionary custom_objects = null, bool compile = false) - { - } - public static void _convert_rnn_weights(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + /// + /// Preprocess layer weights between different Keras formats. + /// + /// + /// + /// + /// + public static List preprocess_weights_for_loading(ILayer layer, List weights, string original_keras_version = null, string original_backend = null) { + // convert CuDNN layers + return _convert_rnn_weights(layer, weights); + } + /// + /// Converts weights for RNN layers between native and CuDNN format. + /// + /// + /// + static List _convert_rnn_weights(ILayer layer, List weights) + { + var target_class = layer.GetType().Name; + return weights; } public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { @@ -65,56 +83,79 @@ namespace Tensorflow.Keras.Saving { } - public static void load_weights_from_hdf5_group(long f=-1,Model model=null) + public static void load_weights_from_hdf5_group(long f, List layers) { - string original_keras_version = "1"; + string original_keras_version = "2.4.0"; string original_backend = null; if (Hdf5.AttributeExists(f, "keras_version")) { - (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", ""); + var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", ""); if (success) - { - original_keras_version = attr[0]; - } + original_keras_version = attr.First(); + // keras version should be 2.5.0+ + var ver_major = int.Parse(original_keras_version.Split('.')[0]); + var ver_minor = int.Parse(original_keras_version.Split('.')[1]); + if (ver_major < 2 || (ver_major == 2 && ver_minor < 5)) + throw new ValueError("keras version should be 2.5.0 or later."); } if (Hdf5.AttributeExists(f, "backend")) { - (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", ""); + var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", ""); if (success) - { - original_backend = attr[0]; - } + original_backend = attr.First(); } List filtered_layers = new List(); - List weights; - foreach (var layer in model.Layers) + List weights; + foreach (var layer in layers) { weights = _legacy_weights(layer); - if (weights.Count>0) + if (weights.Count > 0) { filtered_layers.append(layer); } } - string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names"); - List weight_values=new List(); - foreach (var i in filtered_layers) { - long g = H5G.open(f, i.Name); - string[] weight_names = null; - if (g != -1) - { - weight_names = load_attributes_from_hdf5_group(g, "weight_names"); - } - if (weight_names != null) + string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names"); + var filtered_layer_names = new List(); + foreach(var name in layer_names) + { + long g = H5G.open(f, name); + var weight_names = load_attributes_from_hdf5_group(g, "weight_names"); + if (weight_names.Count() > 0) + filtered_layer_names.Add(name); + H5G.close(g); + } + layer_names = filtered_layer_names.ToArray(); + if (layer_names.Length != filtered_layers.Count()) + throw new ValueError("You are trying to load a weight file " + + $"containing {layer_names}" + + $" layers into a model with {filtered_layers.Count} layers."); + + var weight_value_tuples = new List<(IVariableV1, NDArray)>(); + foreach (var (k, name) in enumerate(layer_names)) + { + var weight_values = new List(); + long g = H5G.open(f, name); + var weight_names = load_attributes_from_hdf5_group(g, "weight_names"); + foreach (var i_ in weight_names) { - foreach (var i_ in weight_names) { - (bool success, Array result) = Hdf5.ReadDataset(g, i_); - // + (bool success, Array result) = Hdf5.ReadDataset(g, i_); + if (success) weight_values.Add(np.array(result)); - } } H5G.close(g); + var layer = filtered_layers[k]; + var symbolic_weights = _legacy_weights(layer); + preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend); + if (weight_values.Count() != symbolic_weights.Count()) + throw new ValueError($"Layer #{k} (named {layer.Name}" + + "in the current model) was found to " + + $"correspond to layer {name} in the save file." + + $"However the new layer {layer.Name} expects " + + $"{symbolic_weights.Count()} weights, but the saved weights have " + + $"{weight_values.Count()} elements."); + weight_value_tuples.AddRange(zip(symbolic_weights, weight_values)); } - + keras.backend.batch_set_value(weight_value_tuples); } public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { @@ -128,15 +169,13 @@ namespace Tensorflow.Keras.Saving { } - public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "") + public static string[] load_attributes_from_hdf5_group(long group, string name) { - if (Hdf5.AttributeExists(f, name)) + if (Hdf5.AttributeExists(group, name)) { - (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, ""); + var (success, attr) = Hdf5.ReadStringAttributes(group, name, ""); if (success) - { - return attr; - } + return attr.ToArray(); } return null; } @@ -145,33 +184,10 @@ namespace Tensorflow.Keras.Saving } - public static List _legacy_weights(ILayer layer) + public static List _legacy_weights(ILayer layer) { - - List weights= new List(); - if (layer.trainable_weights.Count != 0) - { - Tensor[] trainable_weights = Array.ConvertAll(layer.trainable_weights.ToArray(), s => s.AsTensor()); - Tensor[] non_trainable_weights =null; - if (layer.non_trainable_weights.Count != 0) - { - non_trainable_weights = Array.ConvertAll(layer.non_trainable_weights.ToArray(), s => s.AsTensor()); - } - foreach (var i in trainable_weights) { - if (non_trainable_weights != null) - { - foreach (var i_ in non_trainable_weights) - { - weights.Add(i + i_); - } - } - else { - weights.Add(i); - }; - - - } - } + var weights = layer.trainable_weights.Select(x => x).ToList(); + weights.AddRange(layer.non_trainable_weights); return weights; } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index dbbfc19e..e4864e1d 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -46,7 +46,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac -