using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using Tensorflow.Exceptions; using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; using Google.Protobuf; namespace Tensorflow.Checkpoint; public static class SaveUtilV1 { public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, IDictionary? object_map = null) { // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, // till now only internal registrations are allowed. So, we won't return a saver in this function. // The implementation of this function should be updated if tensorflow update it. Dictionary> checkpoint_factory_map = new(); foreach (var pair in object_names) { var trackable = pair.Key; var object_name = pair.Value; var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); // skip the registration process. List current_list = new(); foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) { // treat name as key_suffix. var name = name_and_factory.Key; var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); } checkpoint_factory_map[trackable] = current_list; } return (checkpoint_factory_map, null); } public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { if (to_graph is not null) { var g = to_graph.as_default(); var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); tf.device("/cpu:0"); var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); g.Exit(); return (named_saveable_objects, registered_savers); } else { using (new ops.NullContextManager()) { var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); tf.device("/cpu:0"); var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); } } } public static (List, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); Dictionary object_names = new(); foreach (var pair in node_paths) { object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); } Dictionary node_ids = new(); for (int i = 0; i < trackable_objects.Count; i++) { node_ids[trackable_objects[i]] = i; } var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, saveables_cache); CheckPointUtils.add_checkpoint_values_check(object_graph_proto); return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); } private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList trackable_objects, IDictionary node_ids, IDictionary> slot_variables) { TrackableObjectGraph object_graph_proto = new(); for (int i = 0; i < trackable_objects.Count; i++) { var trackable = trackable_objects[i]; Debug.Assert(node_ids[trackable] == i); TrackableObjectGraph.Types.TrackableObject object_proto; if (slot_variables.TryGetValue(trackable, out var slots)) { object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); } else { object_proto = new TrackableObjectGraph.Types.TrackableObject(); } object_graph_proto.Nodes.Add(object_proto); foreach (var child in graph_view.list_children(trackable)) { object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() { NodeId = node_ids[child.Refer], LocalName = child.Name }); } } return object_graph_proto; } private static (List, object?, IDictionary>?) add_attributes_to_object_graph(IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); for (int i = 0; i < cnt; i++) { Debug.Assert(node_ids[trackable_objects[i]] == i); } var (checkpoint_factory_map, unmmaped_registered_savers) = get_checkpoint_factories_and_keys(object_names, object_map); // skip the process of registered savers var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); return (named_saveable_objects, feed_additions, null); } public static (List, object?) generate_saveable_objects( IDictionary> checkpoint_factory_map, TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { List named_saveable_objects = new(); foreach (var pair in checkpoint_factory_map) { var trackable = pair.Key; var factory_data_list = pair.Value; bool fill_object_proto = object_graph_proto is not null && node_ids is not null; TrackableObjectGraph.Types.TrackableObject object_proto = null!; if (fill_object_proto) { object_proto = object_graph_proto.Nodes[node_ids[trackable]]; } var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); // skip cache foreach (var factory_data in factory_data_list) { var name = factory_data.name; var key = factory_data.checkpoint_key; var maybe_saveable = factory_data.factory; // TODO: oneflow python has a process with callable `saveable_factory`. List saveables = new(); if (maybe_saveable.DataType == typeof(MySaveableObject)) { saveables.Add(maybe_saveable.GetValueB()); } else { saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); } foreach (var saveable in saveables) { if (!saveable.name.Contains(key)) { throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + $"'{saveable.name}' for attribute '{name}'. Expected a name" + $" containing '{key}'."); } } // skip the process of PythonState named_saveable_objects.AddRange(saveables); if(!fill_object_proto) continue; // skip the process of `TrackableSaveable` because of lack of APIs. object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); } } return (named_saveable_objects, null); } } public record class CheckpointFactoryData ( Maybe factory, string name, string checkpoint_key );