|
|
@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine |
|
|
|
{ |
|
|
|
public partial class Model |
|
|
|
{ |
|
|
|
static Dictionary<string, List<(string, NDArray)>> weightsCache |
|
|
|
= new Dictionary<string, List<(string, NDArray)>>(); |
|
|
|
|
|
|
|
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) |
|
|
|
{ |
|
|
|
// Get from cache |
|
|
|
if (weightsCache.ContainsKey(filepath)) |
|
|
|
{ |
|
|
|
var filtered_layers = new List<ILayer>(); |
|
|
|
foreach (var layer in Layers) |
|
|
|
{ |
|
|
|
var weights = hdf5_format._legacy_weights(layer); |
|
|
|
if (weights.Count > 0) |
|
|
|
filtered_layers.append(layer); |
|
|
|
} |
|
|
|
|
|
|
|
var weight_value_tuples = new List<(IVariableV1, NDArray)>(); |
|
|
|
filtered_layers.Select((layer, i) => |
|
|
|
{ |
|
|
|
var symbolic_weights = hdf5_format._legacy_weights(layer); |
|
|
|
foreach(var weight in symbolic_weights) |
|
|
|
{ |
|
|
|
var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2; |
|
|
|
weight_value_tuples.Add((weight, weight_value)); |
|
|
|
} |
|
|
|
return layer; |
|
|
|
}).ToList(); |
|
|
|
|
|
|
|
keras.backend.batch_set_value(weight_value_tuples); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
long fileId = Hdf5.OpenFile(filepath, true); |
|
|
|
if(fileId < 0) |
|
|
|
{ |
|
|
@@ -29,8 +59,11 @@ namespace Tensorflow.Keras.Engine |
|
|
|
throw new NotImplementedException(""); |
|
|
|
else |
|
|
|
{ |
|
|
|
hdf5_format.load_weights_from_hdf5_group(fileId, Layers); |
|
|
|
var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); |
|
|
|
Hdf5.CloseFile(fileId); |
|
|
|
|
|
|
|
weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList(); |
|
|
|
keras.backend.batch_set_value(weight_value_tuples); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|