Browse Source

Allow Model to cache weights.

tags/v0.150.0-BERT-Model
hchen 2 years ago
parent
commit
0f02885dfb
2 changed files with 36 additions and 3 deletions
  1. +34
    -1
      src/TensorFlowNET.Keras/Engine/Model.Training.cs
  2. +2
    -2
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+ 34
- 1
src/TensorFlowNET.Keras/Engine/Model.Training.cs View File

@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
// Cache of weights already loaded by load_weights, keyed by the file path that was passed in.
// Each entry is a list of (weight name, weight value) pairs; on a cache hit the values are
// matched back to the model's variables by name instead of re-reading the file.
// Static, so the cache is shared by every Model instance.
// NOTE(review): no eviction or invalidation is visible in this diff - entries appear to stay
// for the lifetime of the static, even if the file on disk changes; confirm this is intended.
static Dictionary<string, List<(string, NDArray)>> weightsCache
= new Dictionary<string, List<(string, NDArray)>>();

public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
// Get from cache
if (weightsCache.ContainsKey(filepath))
{
var filtered_layers = new List<ILayer>();
foreach (var layer in Layers)
{
var weights = hdf5_format._legacy_weights(layer);
if (weights.Count > 0)
filtered_layers.append(layer);
}

var weight_value_tuples = new List<(IVariableV1, NDArray)>();
filtered_layers.Select((layer, i) =>
{
var symbolic_weights = hdf5_format._legacy_weights(layer);
foreach(var weight in symbolic_weights)
{
var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2;
weight_value_tuples.Add((weight, weight_value));
}
return layer;
}).ToList();

keras.backend.batch_set_value(weight_value_tuples);
return;
}

long fileId = Hdf5.OpenFile(filepath, true);
if(fileId < 0)
{
@@ -29,8 +59,11 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
else
{
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);

weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList();
keras.backend.batch_set_value(weight_value_tuples);
}
}



+ 2
- 2
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving

}

public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
string original_keras_version = "2.5.0";
string original_backend = null;
@@ -152,7 +152,7 @@ namespace Tensorflow.Keras.Saving
weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
}

keras.backend.batch_set_value(weight_value_tuples);
return weight_value_tuples;
}

public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)


Loading…
Cancel
Save