@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Training; | using Tensorflow.Training; | ||||
using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
@@ -13,7 +14,7 @@ public static class CheckPointUtils | |||||
{ | { | ||||
private static string _ESCAPE_CHAR = "."; | private static string _ESCAPE_CHAR = "."; | ||||
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | ||||
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||||
IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||||
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | ||||
{ | { | ||||
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | ||||
@@ -93,13 +93,14 @@ public class SaveableView | |||||
// | // | ||||
// } | // } | ||||
foreach (var obj in _nodes) | |||||
{ | |||||
if (obj is ConcreteFunction) | |||||
{ | |||||
_concrete_functions.Add((ConcreteFunction)obj); | |||||
} | |||||
} | |||||
//_concrete_functions = new(); | |||||
//foreach (var obj in _nodes) | |||||
//{ | |||||
// if (obj is ConcreteFunction) | |||||
// { | |||||
// _concrete_functions.Add((ConcreteFunction)obj); | |||||
// } | |||||
//} | |||||
} | } | ||||
public List<ConcreteFunction> get_concrete_resource_initializers() | public List<ConcreteFunction> get_concrete_resource_initializers() | ||||
@@ -225,8 +226,8 @@ public class SaveableView | |||||
} | } | ||||
else if (obj is ConcreteFunction) | else if (obj is ConcreteFunction) | ||||
{ | { | ||||
// TODO: complete it. | |||||
throw new NotImplementedException(); | |||||
// TODO(Rinne): complete it. | |||||
// throw new NotImplementedException(); | |||||
} | } | ||||
// skip the process of type `_CapturedTensor` and `CapturableResource`. | // skip the process of type `_CapturedTensor` and `CapturableResource`. | ||||
else | else | ||||
@@ -17,7 +17,14 @@ namespace Tensorflow | |||||
{ | { | ||||
protected string _name; | protected string _name; | ||||
public virtual string Name => _handle_name; | public virtual string Name => _handle_name; | ||||
public virtual string SharedName => _name; | |||||
public virtual string SharedName | |||||
{ | |||||
get | |||||
{ | |||||
// TODO(Rinne): optimize the implementation with refactor of variable. | |||||
return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1); | |||||
} | |||||
} | |||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
protected string _handle_name; | protected string _handle_name; | ||||
@@ -152,6 +152,39 @@ namespace Tensorflow.Keras.Saving | |||||
_reconstruct_all_models(); | _reconstruct_all_models(); | ||||
} | } | ||||
/// <summary> | |||||
/// Removes tracked references that are only used when loading the model. | |||||
/// Now that the node object has been fully loaded, and the checkpoint has | |||||
/// been restored, the object no longer needs to track objects added from | |||||
/// SerializedAttributes. (Note that saving a training checkpoint still | |||||
/// functions correctly, because layers and variables are tracked | |||||
/// separately by the Layer object.) | |||||
/// </summary> | |||||
public void del_tracking() | |||||
{ | |||||
foreach(var (node, _) in loaded_nodes.Values) | |||||
{ | |||||
if(node is not Layer layer) | |||||
{ | |||||
continue; | |||||
} | |||||
foreach(var name in PUBLIC_ATTRIBUTES.Keys) | |||||
{ | |||||
layer._delete_tracking(name); | |||||
} | |||||
if(node is Functional functional) | |||||
{ | |||||
foreach(var name in functional.UnconditionalDependencyNames.Keys) | |||||
{ | |||||
if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success) | |||||
{ | |||||
functional._delete_tracking(name); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private void _reconstruct_all_models() | private void _reconstruct_all_models() | ||||
{ | { | ||||
HashSet<int> all_initialized_models = new(); | HashSet<int> all_initialized_models = new(); | ||||
@@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
var loaded = Loader.load_partial(path, nodes_to_load, options); | var loaded = Loader.load_partial(path, nodes_to_load, options); | ||||
keras_loader.finalize_objects(); | keras_loader.finalize_objects(); | ||||
// keras_loader.del_tracking(); | |||||
keras_loader.del_tracking(); | |||||
var model = loaded["root"]; | var model = loaded["root"]; | ||||
@@ -196,5 +196,17 @@ namespace Tensorflow.Keras.UnitTest.Model | |||||
// ) | // ) | ||||
#endregion | #endregion | ||||
} | } | ||||
[TestMethod] | |||||
public void SaveAfterLoad() | |||||
{ | |||||
var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile"); | |||||
model.summary(); | |||||
model.save("Assets/saved_auto_compile_after_loading"); | |||||
//model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading"); | |||||
//model.summary(); | |||||
} | |||||
} | } | ||||
} | } |