Browse Source

Merge pull request #1137 from Beacontownfc/mybranch2

fix: Fix the problem where the  model.save_weights() can't save multi-layer weight
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
a95005fe5b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+ 7
- 2
src/TensorFlowNET.Keras/Saving/hdf5_format.cs View File

@@ -7,6 +7,8 @@ using HDF5CSharp;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System.Linq;
using System.Text.RegularExpressions;

namespace Tensorflow.Keras.Saving
{
public class hdf5_format
@@ -132,7 +134,9 @@ namespace Tensorflow.Keras.Saving
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
foreach (var i_ in weight_names)
{
(success, Array result) = Hdf5.ReadDataset<float>(g, i_);
var vm = Regex.Replace(i_, "/", "$");
vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1);
(success, Array result) = Hdf5.ReadDataset<float>(g, vm);
if (success)
weight_values.Add(np.array(result));
}
@@ -193,7 +197,8 @@ namespace Tensorflow.Keras.Saving
if (name.IndexOf("/") > 1)
{
var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
WriteDataset(crDataGroup, name.Split('/')[1], tensor);
var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$");
WriteDataset(crDataGroup, _name, tensor);
Hdf5.CloseGroup(crDataGroup);
}
else


Loading…
Cancel
Save