@@ -7,6 +7,8 @@ using Tensorflow.Keras.Engine;
using HDF5CSharp;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
    public class fdf5_format
@@ -45,13 +47,29 @@ namespace Tensorflow.Keras.Saving
        {

        }

        /// <summary>
        /// Preprocess layer weights between different Keras formats.
        /// </summary>
        /// <param name="layer">The target layer the weights will be loaded into.</param>
        /// <param name="weights">The weight values read from the save file.</param>
        /// <param name="original_keras_version">The Keras version that wrote the weights.</param>
        /// <param name="original_backend">The Keras backend that wrote the weights.</param>
        public static List<NDArray> preprocess_weights_for_loading(ILayer layer, List<NDArray> weights, string original_keras_version = null, string original_backend = null)
        {
            // convert CuDNN layers
            return _convert_rnn_weights(layer, weights);
        }

        /// <summary>
        /// Converts weights for RNN layers between native and CuDNN format.
        /// </summary>
        /// <param name="layer">The target RNN layer.</param>
        /// <param name="weights">The weight values to convert.</param>
        static List<NDArray> _convert_rnn_weights(ILayer layer, List<NDArray> weights)
        {
            var target_class = layer.GetType().Name;
            // CuDNN conversion is not implemented yet, so the weights pass through unchanged.
            return weights;
        }
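
        // A possible shape for the conversion once implemented, modeled on the
        // Python Keras `_convert_rnn_weights` (sketch only; the "LSTM" class-name
        // check and the weight count are illustrative assumptions):
        //
        //   if (target_class == "LSTM" && weights.Count == 3)
        //   {
        //       // CuDNN stores kernel, recurrent kernel and bias in a fused layout;
        //       // a full port would split/transpose them into the native layout here.
        //   }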

        public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
@@ -65,56 +83,79 @@ namespace Tensorflow.Keras.Saving
        {

        }

        public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
        {
            string original_keras_version = "2.4.0";
            string original_backend = null;
            if (Hdf5.AttributeExists(f, "keras_version"))
            {
                var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
                if (success)
                    original_keras_version = attr.First();

                // keras version should be 2.5.0+
                var ver_major = int.Parse(original_keras_version.Split('.')[0]);
                var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
                if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
                    throw new ValueError("keras version should be 2.5.0 or later.");
            }

            if (Hdf5.AttributeExists(f, "backend"))
            {
                var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
                if (success)
                    original_backend = attr.First();
            }

            List<ILayer> filtered_layers = new List<ILayer>();
            List<IVariableV1> weights;
            foreach (var layer in layers)
            {
                weights = _legacy_weights(layer);
                if (weights.Count > 0)
                {
                    filtered_layers.append(layer);
                }
            }

            string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
            var filtered_layer_names = new List<string>();
            foreach (var name in layer_names)
            {
                long g = H5G.open(f, name);
                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                if (weight_names != null && weight_names.Count() > 0)
                    filtered_layer_names.Add(name);
                H5G.close(g);
            }

            layer_names = filtered_layer_names.ToArray();
            if (layer_names.Length != filtered_layers.Count)
                throw new ValueError("You are trying to load a weight file " +
                    $"containing {layer_names.Length}" +
                    $" layers into a model with {filtered_layers.Count} layers.");

            var weight_value_tuples = new List<(IVariableV1, NDArray)>();
            foreach (var (k, name) in enumerate(layer_names))
            {
                var weight_values = new List<NDArray>();
                long g = H5G.open(f, name);
                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                foreach (var i_ in weight_names)
                {
                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
                    if (success)
                        weight_values.Add(np.array(result));
                }
                H5G.close(g);
                var layer = filtered_layers[k];
                var symbolic_weights = _legacy_weights(layer);
                weight_values = preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
                if (weight_values.Count != symbolic_weights.Count)
                    throw new ValueError($"Layer #{k} (named {layer.Name} " +
                        "in the current model) was found to " +
                        $"correspond to layer {name} in the save file. " +
                        $"However the new layer {layer.Name} expects " +
                        $"{symbolic_weights.Count} weights, but the saved weights have " +
                        $"{weight_values.Count} elements.");
                weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
            }

            keras.backend.batch_set_value(weight_value_tuples);
        }
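
        // Minimal usage sketch for the loader above. The file name and `model`
        // variable are illustrative; assumes HDF5CSharp's Hdf5.OpenFile/CloseFile:
        //
        //   long f = Hdf5.OpenFile("weights.h5", readOnly: true);
        //   try
        //   {
        //       load_weights_from_hdf5_group(f, model.Layers);
        //   }
        //   finally
        //   {
        //       Hdf5.CloseFile(f);
        //   }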

        public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
@@ -128,15 +169,13 @@ namespace Tensorflow.Keras.Saving
        {

        }

        public static string[] load_attributes_from_hdf5_group(long group, string name)
        {
            if (Hdf5.AttributeExists(group, name))
            {
                var (success, attr) = Hdf5.ReadStringAttributes(group, name, "");
                if (success)
                    return attr.ToArray();
            }
            return null;
        }
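
        // Usage sketch: given an open HDF5 file id `f`, the top-level layer list
        // written by save_weights can be read back as
        //
        //   string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");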
@@ -145,33 +184,10 @@ namespace Tensorflow.Keras.Saving

        }

        public static List<IVariableV1> _legacy_weights(ILayer layer)
        {
            var weights = layer.trainable_weights.ToList();
            weights.AddRange(layer.non_trainable_weights);
            return weights;
        }
    }