diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs new file mode 100644 index 00000000..dc2a92fb --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -0,0 +1,253 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint +{ + internal record class TrackableData( + // A trackable in the root Trackable object graph. + Trackable trackable, + // The index at which the Trackable appears in TrackableObjectGraph.nodes. + int node_id, + // The BFS-generated path from the root object / used to generate readable checkpoint keys. + string object_name, + // A list of ObjectReference for each child connected to this Trackable. + pbc::RepeatedField children_proto, + // A list of SlotVariableReference to save to the object (only valid for Optimizer objects). + pbc::RepeatedField slot_variable_proto, + // The object to save to checkpoint. Usually this is the same as `trackable`, + // but can differ when the the caller wants to specify a different object to + // save. For example, when saving checkpoints asynchronously, variables are + // copied to the CPU. `object_to_save` is set as the copied variable. + Trackable object_to_save + ); + public static class SaveUtil + { + public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) + { + var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); + var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data); + + var object_graph_proto = fill_object_graph_proto(trackable_data); + + var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); + var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); + + Dictionary feed_additions; + if(cache is null) + { + feed_additions = null; + serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures, + cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value); + } + else + { + feed_additions = null; + // TODO: deal with cache. + throw new NotFiniteNumberException(); + } + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + + return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); + } + + private static (List, Dictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? object_map) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach(var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + Dictionary node_ids = new(); + for(int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + List trackable_data = new(); + foreach(var trackable in trackable_objects) + { + pbc::RepeatedField children_proto = new(); + foreach(var child in graph_view.list_children(trackable)) + { + children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { + NodeId = node_ids[child.Refer], + LocalName = child.Name + }); + } + slot_variables.TryGetValue(trackable, out var slot_variable); + trackable_data.Add(new TrackableData( + trackable: trackable, + node_id: node_ids[trackable], + object_name: object_names[trackable], + children_proto: children_proto, + slot_variable_proto: slot_variable??new pbc.RepeatedField(), + object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map) + )); + } + return (trackable_data, node_ids); + } + + private static TrackableObjectGraph fill_object_graph_proto(IList trackable_data) + { + TrackableObjectGraph object_graph_proto = new(); + for(int i = 0; i < trackable_data.Count; i++) + { + var td = trackable_data[i]; + Debug.Assert(td.node_id == i); + object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); + } + return object_graph_proto; + } + + /// + /// Creates dictionary of tensors to checkpoint, and updates the proto. + /// + /// + /// + /// + /// + /// + private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) + { + Dictionary> serialized_tensors = new(); + foreach(var td in tensor_trackables) + { + // TODO: deal with cache. + var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; + var trackable = td.object_to_save; + IDictionary tensor_dict; + if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) + { + (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); + } + else + { + tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + } + if(trackable is not null) + { + serialized_tensors[trackable] = tensor_dict; + } + else + { + serialized_tensors[Trackable.None] = tensor_dict; + } + } + return serialized_tensors; + } + + private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + var trackable = trackable_data.object_to_save; + + // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. + IDictionary ret_tensor_dict; + if (call_with_mapped_captures) + { + throw new NotImplementedException(); + } + else + { + ret_tensor_dict = trackable.serialize_to_tensors(); + } + + // TODO: revise the types and complete it + Dictionary tensor_dict = new(); + foreach(var pair in ret_tensor_dict) + { + var local_name = TrackableUtils.escape_local_name(pair.Key); + var maybe_tensor = pair.Value; + var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name); + + tensor_dict[checkpoint_key] = maybe_tensor; + + if(maybe_tensor is SaveSpec) + { + ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + } + + if(object_graph_proto is not null) + { + object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { + Name = local_name, + CheckpointKey = checkpoint_key, + FullName = CheckPointUtils.get_full_name(trackable) + }); + } + } + return tensor_dict; + } + + /// + /// Gets tensors to serialize from a Trackable with legacy SaveableObjects. + /// + /// + /// + /// + /// + /// + private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + Dictionary object_names = new(); + object_names[trackable_data.trackable] = trackable_data.object_name; + Dictionary object_map = new(); + object_map[trackable_data.trackable] = trackable_data.object_to_save; + + var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map); + var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, + call_with_mapped_captures, saveables_cache: null); + var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects); + return (trackable, trackable.serialize_to_tensors()); + } + + private static IDictionary> get_and_write_registered_savers(IDictionary> registered_trackables, TrackableObjectGraph object_graph_proto) + { + Dictionary> registered_savers = new(); + foreach(var pair in registered_trackables) + { + foreach(var td in pair.Value) + { + if (registered_savers.ContainsKey(pair.Key)) + { + registered_savers[pair.Key] = new Dictionary(); + } + else + { + registered_savers[pair.Key][td.object_name] = td.object_to_save; + } + + var object_proto = object_graph_proto.Nodes[td.node_id]; + // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`. + } + } + return registered_savers; + } + + private static (IList, IList, IDictionary>) split_trackables(IEnumerable trackable_data) + { + List tensor_trackables = new(); + List py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a pleceholder. + Dictionary> registered_trackables = new(); + + foreach(var td in trackable_data) + { + // TODO: deal with registration. + tensor_trackables.Add(td); + } + return (tensor_trackables, py_state_trackables, registered_trackables); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 7724c6b7..44fa5c5d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -7,6 +7,7 @@ using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Google.Protobuf; namespace Tensorflow.Checkpoint; @@ -47,19 +48,16 @@ public static class SaveUtilV1 IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { - - Graph target_context; if (to_graph is not null) { - using (to_graph.as_default()) - { - var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + to_graph.as_default(); + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); - named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - return (named_saveable_objects, registered_savers); - } + // tensorflow python: `with ops.device("/cpu:0")` + var serialized = graph_proto.ToByteString().ToString(); + var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); } else { diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs new file mode 100644 index 00000000..fa441d79 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint +{ + internal static class SaveableCompat + { + public static string? get_saveable_name(Trackable cls_or_obj) + { + // TODO: implement it with Attribute. + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs deleted file mode 100644 index 7d101d5e..00000000 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs +++ /dev/null @@ -1,109 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Contexts; -using Tensorflow.Eager; - -namespace Tensorflow.Checkpoint; - -public class TrackableSaver -{ - private ObjectGraphView _graph_view; - private EagerTensor _cached_save_operation; - private TrackableObjectGraph _last_save_object_graph; - private Tensor? _object_graph_feed_tensor = null; - private Tensor? _file_prefix_feed_tensor = null; - public TrackableSaver(ObjectGraphView graph_view) - { - _graph_view = graph_view; - - // TODO: cache when not executing eagerly. - // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, - // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` - - } - - private void gather_serialized_tensors(Tensor? object_graph_tensor = null) - { - throw new NotImplementedException(); - } - - private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) - { - throw new NotImplementedException(); - } - - // TODO: parameter write_done_callback - public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, - CheckpointOptions? options = null) - { - if (options is null) - { - options = new CheckpointOptions(); - } - - Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); - if (checkpoint_number is not null) - { - file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; - } - - Tensor file_prefix_tensor; - Tensor object_graph_tensor; - if (use_session) - { - if (_object_graph_feed_tensor is null) - { - // In python there is `with ops.device("/cpu:0")`. - _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); - _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); - } - - object_graph_tensor = _object_graph_feed_tensor; - file_prefix_tensor = _file_prefix_feed_tensor; - feed_dict[file_prefix_tensor] = file_prefix; - } - else - { - // In python there is `with ops.device("/cpu:0")`. - file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); - object_graph_tensor = null; - } - - var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); - - if (new_feed_additions is not null) - { - foreach (var pair in new_feed_additions) - { - feed_dict.Add(pair.Key, pair.Value); - } - } - if(!use_session) - { - session = null; - } - else if (session is null) - { - session = new Session(); // In python it uses `get_session`. - } - - if (session is not null) - { - var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); - return session.run((Tensor)save_path, s); - } - else if (use_session) - { - throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + - "in graph mode without a default session. Please use " + - "`with tf.Session():` to create a session."); - } - else - { - return save_path; - } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index ed1f3ec4..6d81d2c9 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -21,9 +21,14 @@ public class TrackableView public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) { obj._maybe_initialize_trackable(); + Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - return obj._trackable_children(save_type); + foreach(var pair in obj._trackable_children(save_type)) + { + children[pair.Key] = pair.Value; + } + return children; } public Trackable Root @@ -50,6 +55,7 @@ public class TrackableView { List bfs_sorted = new(); Queue to_visit = new(); + to_visit.Enqueue(Root); Dictionary> node_paths = new(); node_paths[this.Root] = new List(); while (!to_visit.empty()) diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs new file mode 100644 index 00000000..79109489 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -0,0 +1,191 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Train; +using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +/// +/// Saves and restores a `Trackable` object and its dependencies. +/// +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private Tensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? _file_prefix_feed_tensor = null; + private Dictionary? _object_map = null; + private object? _cache = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); + + // TODO: cache. + + if(object_graph_tensor is null) + { + // tensorflow python: `with ops.device("/cpu:0"):` + object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + } + else + { + feed_additions[object_graph_tensor] = graph_proto.ToString(); + } + Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + if (serialized_tensors.ContainsKey(Trackable.None)) + { + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + } + return (serialized_tensors, feed_additions, registered_savers, graph_proto); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] { save_op })) + { + _cached_save_operation = array_ops.identity(file_prefix); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] {save_op} )) + { + _cached_save_operation = array_ops.identity(tf.constant(file_prefix)); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. + file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs new file mode 100644 index 00000000..759cbd66 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -0,0 +1,36 @@ +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.OptimizerOptions.Types; + +namespace Tensorflow.Checkpoint +{ + /// + /// Saves checkpoints directly from multiple devices. + /// Note that this is a low-level utility which stores Tensors in the keys + /// specified by `SaveableObject`s.Higher-level utilities for object-based + /// checkpointing are built on top of it. + /// + public class MultiDeviceSaver + { + public MultiDeviceSaver(IDictionary> serialized_tensors, + IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) + { + + } + + public Operation? save(string file_prefix, CheckpointOptions? options= null) + { + throw new NotImplementedException(); + } + + public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 93413667..fb197eca 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -205,6 +205,16 @@ namespace Tensorflow { slotVariables_ = slot; } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot, + pbc::RepeatedField children + ) + { + OnConstruction(); + slotVariables_ = slot; + children_ = children; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d8f6314b..6f10fd2e 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,4 +1,10 @@ -namespace Tensorflow.Train +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Train { public abstract class AutoTrackable : Trackable { @@ -17,5 +23,48 @@ } } } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if(save_type != SaveType.SAVEDMODEL) + { + return base._trackable_children(save_type, cache); + } + + Dictionary functions = new(); + // TODO: process of logs. + var properties = this.GetType().GetProperties(); + foreach ( var property in properties ) + { + string name = property.Name; + object value = property.GetValue(this, null); + if(value is Function || value is ConcreteFunction) + { + functions[name] = (Trackable)value; + } + } + + // TODO: process the type `core_types.GenericFunction`. + + Dictionary children = new(); + foreach(var pair in CheckpointDependencies) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) // or Generic function + { + continue; + } + if(functions.ContainsKey(name) && functions[name] != child) + { + throw new ValueError($"Can't save object because it has multiple children with the same " + + $"name. Object: {this}, attribute name: {name}, child 1: " + + $"{child}, child 2: {functions[name]}"); + } + children[name] = child; + } + + return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs index 1ae912ce..393a6a98 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -28,7 +28,7 @@ namespace Tensorflow public string slice_spec => _slice_spec; private string _name; - public string name => _name; + public string name { get => _name; set => _name = value; } private TF_DataType _dtype; public TF_DataType dtype => _dtype; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index 69235605..cc839952 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -134,35 +134,33 @@ public static partial class SavedModelUtils Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - using (var g = exported_graph.as_default()) + exported_graph.as_default(); + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. + if (save_custom_gradients) { - (object_map, tensor_map, asset_info) = saveable_view.map_resources(); - // TODO: deal with signatures. - if (save_custom_gradients) - { - // TODO: trace gradient functions. - } + // TODO: trace gradient functions. + } - foreach (var resource_initializer_function in resource_initializers) - { - // List asset_dependencies = new(); - // TODO: deal with initializers - } - - // using(ops.control_dependencies(...)) - var init_op = control_flow_ops.no_op(); - if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); - } - else - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); - } - // Lack `CopyFrom` API - // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers } - + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); @@ -180,11 +178,13 @@ public static partial class SavedModelUtils verify_ops(graph_def, namespace_whitelist); meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef = new(); meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; // TODO: add git version. meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList = new(); meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 98cdb274..622eed3a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -138,5 +138,55 @@ namespace Tensorflow // TODO: implement it. return false; } + + internal static string convert_to_string(string x) + { + return tf.compat.as_str(x); + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private Trackable _obj; + private IList _saveables; + public SaveableCompatibilityConverter(Trackable obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public Trackable Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary serialize_to_tensors() + { + return saveable_objects_to_tensor_dict(_saveables); + } + + /// + /// Converts a list of SaveableObjects to a tensor dictionary. + /// + /// + public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + { + Dictionary tensor_dict = new(); + foreach (var saveable in saveables) + { + foreach(var spec in saveable.specs) + { + var name = saveable_object_util.convert_to_string(spec.name); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + if (!string.IsNullOrEmpty(slice_spec)) + { + throw new NotImplementedException(); + } + else + { + tensor_dict[name] = spec.tensor; + } + } + } + return tensor_dict; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index dce0be2a..b98075d3 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -34,18 +34,35 @@ namespace Tensorflow.Train public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; } protected int _self_update_uid; - protected IDictionary _unconditional_dependency_names = - new Dictionary(); + protected IDictionary _unconditional_dependency_names; - protected IList _unconditional_checkpoint_dependencies = new List(); + protected IList _unconditional_checkpoint_dependencies; protected IDictionary _self_saveable_object_factories = new Dictionary(); + + private static Trackable _none = new Function(); + /// + /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. + /// The `None` can be any object that inherits `Trackable`. + /// This Property is supposed to be used only internal. + /// + public static Trackable None + { + get + { + return _none; + } + } public virtual string ObjectIdentifier { get => "_generic_user_object"; } - + public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } + public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } + public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } + public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -99,8 +116,9 @@ namespace Tensorflow.Train /// public void _maybe_initialize_trackable() { - // _self_unconditional_checkpoint_dependencies = [] _self_update_uid = -1; + _unconditional_checkpoint_dependencies = new List(); + _unconditional_dependency_names = new Dictionary(); } // TODO: cache @@ -153,6 +171,20 @@ namespace Tensorflow.Train { return _self_saveable_object_factories; } + + /// + /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors` + /// if you are defining a custom resource or variable with custom ops. + /// Otherwise, please store the state of your trackable in `tf.Variable` objects + /// and add them to Trackable object hierarchy using `setattr` (for subclasses + /// of `AutoTrackable`) or overriding the `_trackable_children` method. + /// + /// + /// + public virtual IDictionary serialize_to_tensors() + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer);