Browse Source

fix h5 saving.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
4b4f7f8e76
1 changed files with 3 additions and 4 deletions
  1. +3
    -4
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+ 3
- 4
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -179,7 +179,6 @@ namespace Tensorflow.Keras.Saving
Hdf5.WriteAttribute(f, "backend", "tensorflow"); Hdf5.WriteAttribute(f, "backend", "tensorflow");
Hdf5.WriteAttribute(f, "keras_version", "2.5.0"); Hdf5.WriteAttribute(f, "keras_version", "2.5.0");


long g = 0, crDataGroup=0;
foreach (var layer in layers) foreach (var layer in layers)
{ {
var weights = _legacy_weights(layer); var weights = _legacy_weights(layer);
@@ -191,20 +190,20 @@ namespace Tensorflow.Keras.Saving
foreach (var weight in weights) foreach (var weight in weights)
weight_names.Add(weight.Name); weight_names.Add(weight.Name);
g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
var g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray()); save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
foreach (var (name, val) in zip(weight_names, weights)) foreach (var (name, val) in zip(weight_names, weights))
{ {
var tensor = val.AsTensor(); var tensor = val.AsTensor();
if (name.IndexOf("/") > 1) if (name.IndexOf("/") > 1)
{ {
crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
WriteDataset(crDataGroup, name.Split('/')[1], tensor); WriteDataset(crDataGroup, name.Split('/')[1], tensor);
Hdf5.CloseGroup(crDataGroup); Hdf5.CloseGroup(crDataGroup);
} }
else else
{ {
WriteDataset(crDataGroup, name, tensor);
WriteDataset(g, name, tensor);
} }
} }
Hdf5.CloseGroup(g); Hdf5.CloseGroup(g);


Loading…
Cancel
Save