You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SaveUtilV1.cs 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Linq;
  5. using Tensorflow.Exceptions;
  6. using Tensorflow.Train;
  7. using Tensorflow.Training;
  8. using pbc = global::Google.Protobuf.Collections;
  9. using static Tensorflow.Binding;
  10. using Google.Protobuf;
  11. namespace Tensorflow.Checkpoint;
  12. public static class SaveUtilV1
  13. {
  /// <summary>
  /// Collects, for every trackable listed in <paramref name="object_names"/>, the checkpoint
  /// factory entries produced by <c>saveable_object_util.saveable_objects_from_trackable</c>.
  /// </summary>
  /// <param name="object_names">Maps each trackable to the object name used to build its checkpoint keys.</param>
  /// <param name="object_map">Optional mapping forwarded to <c>CheckPointUtils.get_mapped_trackable</c>.</param>
  /// <returns>
  /// A tuple whose first item maps each original (unmapped) trackable to its list of
  /// <see cref="CheckpointFactoryData"/>; the second item is always <c>null</c>.
  /// </returns>
  public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(
      IDictionary<Trackable, string> object_names,
      IDictionary<Trackable, Trackable>? object_map = null)
  {
      // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
      // only internal registrations are allowed so far, so no saver is returned from this function.
      // This implementation should be revisited if TensorFlow changes that.
      var factories_by_trackable = new Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>();
      foreach (var entry in object_names)
      {
          // The registration process is skipped; the saveables come from the mapped trackable.
          var mapped = CheckPointUtils.get_mapped_trackable(entry.Key, object_map);
          // The name returned with each factory is treated as the key suffix.
          factories_by_trackable[entry.Key] = saveable_object_util.saveable_objects_from_trackable(mapped)
              .Select(attribute => new CheckpointFactoryData(
                  attribute.Value,
                  attribute.Key,
                  TrackableUtils.checkpoint_key(entry.Value, attribute.Key)))
              .ToList();
      }
      return (factories_by_trackable, null);
  }
  /// <summary>
  /// Gathers the saveable objects of <paramref name="graph_view"/> via
  /// <c>serialize_gathered_objects</c> and appends a <c>NoRestoreSaveable</c> that stores the
  /// object graph under <c>Trackable.Constants.OBJECT_GRAPH_PROTO_KEY</c>.
  /// </summary>
  /// <param name="graph_view">View of the object graph to serialize.</param>
  /// <param name="object_map">Mapping forwarded to <c>serialize_gathered_objects</c>.</param>
  /// <param name="to_graph">
  /// When not <c>null</c>, the work is done between <c>to_graph.as_default()</c> and the matching
  /// <c>Exit()</c>; otherwise it runs inside an <c>ops.NullContextManager</c>.
  /// </param>
  /// <param name="call_with_mapped_captures">Forwarded to <c>serialize_gathered_objects</c>.</param>
  /// <param name="saveables_cache">Forwarded to <c>serialize_gathered_objects</c>.</param>
  /// <returns>The saveable objects (including the object-graph saveable) and the registered savers.</returns>
  public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
      IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
      object? saveables_cache = null)
  {
      if (to_graph is not null)
      {
          var g = to_graph.as_default();
          try
          {
              var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                  object_map, call_with_mapped_captures, saveables_cache);
              // NOTE(review): the call is made as a plain statement, not as a scope around the
              // constant below - confirm this has the intended device-placement effect.
              tf.device("/cpu:0");
              var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
              named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
              return (named_saveable_objects, registered_savers);
          }
          finally
          {
              // Always leave the graph context entered by as_default(), even when serialization throws.
              g.Exit();
          }
      }

      using (new ops.NullContextManager())
      {
          var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
              object_map, call_with_mapped_captures, saveables_cache);
          // NOTE(review): same un-scoped tf.device call as in the graph branch - confirm intent.
          tf.device("/cpu:0");
          // NOTE(review): this branch stores graph_proto.ToString() while the graph branch stores
          // graph_proto.ToByteArray() - confirm the two encodings are meant to differ.
          var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
          named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
          return (named_saveable_objects, registered_savers);
      }
  }
  /// <summary>
  /// Traverses <paramref name="graph_view"/> breadth-first, assigns each trackable a node id equal
  /// to its position in the traversal, builds the object graph proto and gathers the saveable objects.
  /// </summary>
  /// <param name="graph_view">View of the object graph to traverse.</param>
  /// <param name="object_map">Mapping forwarded to <c>add_attributes_to_object_graph</c>.</param>
  /// <param name="call_with_mapped_captures">Forwarded to <c>add_attributes_to_object_graph</c>.</param>
  /// <param name="saveables_cache">Forwarded to <c>add_attributes_to_object_graph</c>.</param>
  /// <returns>The saveable objects, the object graph proto, the feed additions and the registered savers.</returns>
  public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
      IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
  {
      var (ordered_trackables, paths_by_trackable) = graph_view.breadth_first_traversal();

      // Object name of each trackable, derived from its path in the traversal.
      var names_by_trackable = new Dictionary<Trackable, string>();
      foreach (var path_entry in paths_by_trackable)
      {
          names_by_trackable[path_entry.Key] = TrackableUtils.object_path_to_string(path_entry.Value);
      }

      // Node id of each trackable is its index in traversal order.
      var ids_by_trackable = new Dictionary<Trackable, int>();
      int next_id = 0;
      foreach (var trackable in ordered_trackables)
      {
          ids_by_trackable[trackable] = next_id;
          next_id++;
      }

      var slot_variables = CheckPointUtils.serialize_slot_variables(ordered_trackables, ids_by_trackable, names_by_trackable);
      var graph_proto = fill_object_graph_proto(graph_view, ordered_trackables, ids_by_trackable, slot_variables);
      var (saveables, feed_additions, registered_savers) = add_attributes_to_object_graph(
          ordered_trackables, graph_proto, ids_by_trackable, names_by_trackable, object_map,
          call_with_mapped_captures, saveables_cache);
      CheckPointUtils.add_checkpoint_values_check(graph_proto);
      return (saveables, graph_proto, feed_additions, registered_savers);
  }
  /// <summary>
  /// Creates a <c>TrackableObjectGraph</c> with one node per entry of
  /// <paramref name="trackable_objects"/>, in list order, and records each node's children.
  /// </summary>
  /// <param name="graph_view">Supplies the children of every trackable via <c>list_children</c>.</param>
  /// <param name="trackable_objects">Trackables in node-id order.</param>
  /// <param name="node_ids">Node id of each trackable; expected to equal its index in <paramref name="trackable_objects"/>.</param>
  /// <param name="slot_variables">Slot variable references passed to the node constructor for trackables that have an entry.</param>
  /// <returns>The populated object graph proto.</returns>
  private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
      IDictionary<Trackable, int> node_ids,
      IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
          slot_variables)
  {
      var graph_proto = new TrackableObjectGraph();
      int position = 0;
      foreach (var trackable in trackable_objects)
      {
          Debug.Assert(node_ids[trackable] == position);

          // Nodes with slot variables are built from them; all others start empty.
          var node = slot_variables.TryGetValue(trackable, out var slots)
              ? new TrackableObjectGraph.Types.TrackableObject(slots)
              : new TrackableObjectGraph.Types.TrackableObject();
          graph_proto.Nodes.Add(node);

          foreach (var child in graph_view.list_children(trackable))
          {
              var reference = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
              reference.NodeId = node_ids[child.Refer];
              reference.LocalName = child.Name;
              node.Children.Add(reference);
          }

          position++;
      }
      return graph_proto;
  }
  /// <summary>
  /// Builds the checkpoint factories for the named objects and generates the saveable objects,
  /// filling the attributes of <paramref name="object_graph_proto"/> along the way.
  /// </summary>
  /// <param name="trackable_objects">Trackables in node-id order.</param>
  /// <param name="object_graph_proto">Object graph proto passed on to <c>generate_saveable_objects</c>.</param>
  /// <param name="node_ids">Node id of each trackable.</param>
  /// <param name="object_names">Object name of each trackable, used to build checkpoint keys.</param>
  /// <param name="object_map">Mapping forwarded to the factory and saveable generation steps.</param>
  /// <param name="call_with_mapped_captures">Forwarded to <c>generate_saveable_objects</c>.</param>
  /// <param name="saveables_cache">Forwarded to <c>generate_saveable_objects</c>.</param>
  /// <returns>The saveable objects, the feed additions, and <c>null</c> for the registered savers.</returns>
  private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
      TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
      IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
      bool call_with_mapped_captures, object? saveables_cache = null)
  {
      // Sanity check (debug builds only): node ids must match list positions for every
      // index present in both the trackable list and the proto's node list.
      int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
      for (int i = 0; i < cnt; i++)
      {
          Debug.Assert(node_ids[trackable_objects[i]] == i);
      }

      // The second tuple element (registered savers) is not used: the registered-saver
      // process is skipped here.
      var (checkpoint_factory_map, _) = get_checkpoint_factories_and_keys(object_names, object_map);
      var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
          object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
      return (named_saveable_objects, feed_additions, null);
  }
  /// <summary>
  /// Turns every checkpoint factory entry into saveable objects and, when both
  /// <paramref name="object_graph_proto"/> and <paramref name="node_ids"/> are supplied, records a
  /// <c>SerializedTensor</c> attribute on the matching proto node for each entry.
  /// </summary>
  /// <param name="checkpoint_factory_map">Factory entries per trackable.</param>
  /// <param name="object_graph_proto">Optional proto whose nodes receive the attributes.</param>
  /// <param name="node_ids">Optional node id of each trackable, used to index the proto's nodes.</param>
  /// <param name="object_map">Mapping forwarded to <c>CheckPointUtils.get_mapped_trackable</c>.</param>
  /// <param name="call_with_mapped_captures">Not read by this method.</param>
  /// <param name="saveables_cache">Not read by this method (the cache is skipped).</param>
  /// <returns>All generated saveable objects; the second item is always <c>null</c>.</returns>
  /// <exception cref="AssertionError">A produced saveable's name does not contain its checkpoint key.</exception>
  public static (List<MySaveableObject>, object?) generate_saveable_objects(
      IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
      TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
      IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
  {
      var collected = new List<MySaveableObject>();
      // Attributes are only recorded when both the proto and the id map are available.
      bool record_attributes = !(object_graph_proto is null || node_ids is null);

      foreach (var entry in checkpoint_factory_map)
      {
          var trackable = entry.Key;
          var node_proto = record_attributes
              ? object_graph_proto!.Nodes[node_ids![trackable]]
              : null;
          var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);

          // The cache is skipped.
          foreach (var factory_data in entry.Value)
          {
              var name = factory_data.name;
              var key = factory_data.checkpoint_key;
              var factory = factory_data.factory;

              // TODO: oneflow python has a process with callable `saveable_factory`.
              var produced = new List<MySaveableObject>();
              if (factory.DataType != typeof(MySaveableObject))
              {
                  produced.AddRange(saveable_object_util.saveable_objects_for_op(factory.GetValueA() as Trackable, key));
              }
              else
              {
                  produced.Add(factory.GetValueB());
              }

              // Every produced saveable must carry the checkpoint key in its name.
              var mismatched = produced.Find(candidate => !candidate.name.Contains(key));
              if (mismatched is not null)
              {
                  throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
                      $"'{mismatched.name}' for attribute '{name}'. Expected a name" +
                      $" containing '{key}'.");
              }

              // The PythonState process is skipped.
              collected.AddRange(produced);

              if (record_attributes)
              {
                  // The `TrackableSaveable` process is skipped because of lack of APIs.
                  var attribute = new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor();
                  attribute.Name = name;
                  attribute.CheckpointKey = key;
                  attribute.FullName = CheckPointUtils.get_full_name(object_to_save);
                  node_proto!.Attributes.Add(attribute);
              }
          }
      }
      return (collected, null);
  }
  186. }
  /// <summary>
  /// One checkpointable attribute of a trackable: the saveable source together with the names
  /// under which it is recorded.
  /// </summary>
  /// <param name="factory">Either a resource variable or an already-built saveable object.</param>
  /// <param name="name">Attribute name (used as the key suffix).</param>
  /// <param name="checkpoint_key">Checkpoint key built from the owning object's name and <paramref name="name"/>.</param>
  public record class CheckpointFactoryData(Maybe<BaseResourceVariable, MySaveableObject> factory, string name, string checkpoint_key);