Browse Source

fix: partially fix the error when saving model after loading.

tags/v0.100.5-BERT-load
AsakusaRinne Haiping 2 years ago
parent
commit
8b53eb3e5d
6 changed files with 66 additions and 12 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  2. +10
    -9
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  3. +8
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  4. +33
    -0
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  5. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
  6. +12
    -0
      test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs

+ 2
- 1
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
@@ -13,7 +14,7 @@ public static class CheckPointUtils
{
private static string _ESCAPE_CHAR = ".";
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();


+ 10
- 9
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -93,13 +93,14 @@ public class SaveableView
//
// }

foreach (var obj in _nodes)
{
if (obj is ConcreteFunction)
{
_concrete_functions.Add((ConcreteFunction)obj);
}
}
//_concrete_functions = new();
//foreach (var obj in _nodes)
//{
// if (obj is ConcreteFunction)
// {
// _concrete_functions.Add((ConcreteFunction)obj);
// }
//}
}

public List<ConcreteFunction> get_concrete_resource_initializers()
@@ -225,8 +226,8 @@ public class SaveableView
}
else if (obj is ConcreteFunction)
{
// TODO: complete it.
throw new NotImplementedException();
// TODO(Rinne): complete it.
// throw new NotImplementedException();
}
// skip the process of type `_CapturedTensor` and `CapturableResource`.
else


+ 8
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -17,7 +17,14 @@ namespace Tensorflow
{
protected string _name;
public virtual string Name => _handle_name;
public virtual string SharedName => _name;
public virtual string SharedName
{
get
{
// TODO(Rinne): optimize the implementation with refactor of variable.
return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1);
}
}
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;
protected string _handle_name;


+ 33
- 0
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -152,6 +152,39 @@ namespace Tensorflow.Keras.Saving
_reconstruct_all_models();
}

/// <summary>
/// Removes tracked references that are only used when loading the model.
/// Now that the node object has been fully loaded, and the checkpoint has
/// been restored, the object no longer needs to track objects added from
/// SerializedAttributes. (Note that saving a training checkpoint still
/// functions correctly, because layers and variables are tracked
/// separately by the Layer object.)
/// </summary>
public void del_tracking()
{
foreach(var (node, _) in loaded_nodes.Values)
{
if(node is not Layer layer)
{
continue;
}
foreach(var name in PUBLIC_ATTRIBUTES.Keys)
{
layer._delete_tracking(name);
}
if(node is Functional functional)
{
foreach(var name in functional.UnconditionalDependencyNames.Keys)
{
if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success)
{
functional._delete_tracking(name);
}
}
}
}
}

private void _reconstruct_all_models()
{
HashSet<int> all_initialized_models = new();


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
var loaded = Loader.load_partial(path, nodes_to_load, options);

keras_loader.finalize_objects();
// keras_loader.del_tracking();
keras_loader.del_tracking();

var model = loaded["root"];



+ 12
- 0
test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs View File

@@ -196,5 +196,17 @@ namespace Tensorflow.Keras.UnitTest.Model
// )
#endregion
}

[TestMethod]
public void SaveAfterLoad()
{
var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
model.summary();

model.save("Assets/saved_auto_compile_after_loading");

//model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading");
//model.summary();
}
}
}

Loading…
Cancel
Save