|
|
@@ -152,6 +152,39 @@ namespace Tensorflow.Keras.Saving |
|
|
|
_reconstruct_all_models(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Removes tracked references that are only used when loading the model. |
|
|
|
/// Now that the node object has been fully loaded, and the checkpoint has |
|
|
|
/// been restored, the object no longer needs to track objects added from |
|
|
|
/// SerializedAttributes. (Note that saving a training checkpoint still |
|
|
|
/// functions correctly, because layers and variables are tracked |
|
|
|
/// separately by the Layer object.) |
|
|
|
/// </summary> |
|
|
|
public void del_tracking() |
|
|
|
{ |
|
|
|
foreach(var (node, _) in loaded_nodes.Values) |
|
|
|
{ |
|
|
|
if(node is not Layer layer) |
|
|
|
{ |
|
|
|
continue; |
|
|
|
} |
|
|
|
foreach(var name in PUBLIC_ATTRIBUTES.Keys) |
|
|
|
{ |
|
|
|
layer._delete_tracking(name); |
|
|
|
} |
|
|
|
if(node is Functional functional) |
|
|
|
{ |
|
|
|
foreach(var name in functional.UnconditionalDependencyNames.Keys) |
|
|
|
{ |
|
|
|
if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success) |
|
|
|
{ |
|
|
|
functional._delete_tracking(name); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private void _reconstruct_all_models() |
|
|
|
{ |
|
|
|
HashSet<int> all_initialized_models = new(); |
|
|
|