You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SaveUtil.cs 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. using OneOf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow.Train;
  8. using Tensorflow.Training;
  9. using Tensorflow.Common.Extensions;
  10. using pbc = global::Google.Protobuf.Collections;
  11. namespace Tensorflow.Checkpoint
  12. {
  /// <summary>
  /// Bundle of per-node information collected for one trackable of the object graph.
  /// </summary>
  /// <param name="trackable">A trackable in the root Trackable object graph.</param>
  /// <param name="node_id">The index at which the Trackable appears in TrackableObjectGraph.nodes.</param>
  /// <param name="object_name">The BFS-generated path from the root object; used to generate readable checkpoint keys.</param>
  /// <param name="children_proto">A list of ObjectReference for each child connected to this Trackable.</param>
  /// <param name="slot_variable_proto">A list of SlotVariableReference to save to the object (only valid for Optimizer objects).</param>
  /// <param name="object_to_save">
  /// The object to save to checkpoint. Usually this is the same as <paramref name="trackable"/>,
  /// but can differ when the caller wants to specify a different object to save. For example,
  /// when saving checkpoints asynchronously, variables are copied to the CPU and
  /// <paramref name="object_to_save"/> is set as the copied variable.
  /// </param>
  internal record class TrackableData(
      Trackable trackable,
      int node_id,
      string object_name,
      pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
      pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
      Trackable object_to_save
  );
  30. public static class SaveUtil
  31. {
  /// <summary>
  /// Gathers the trackables reachable from <paramref name="graph_view"/> and produces everything
  /// needed to write a checkpoint: the tensors to serialize, the feed additions, the registered
  /// savers and the filled-in object graph proto.
  /// </summary>
  /// <param name="graph_view">View over the object graph; it is traversed breadth-first.</param>
  /// <param name="object_map">Optional mapping used to pick the object that is actually saved for each trackable.</param>
  /// <param name="call_with_mapped_captures">Forwarded to the tensor-gathering helpers.</param>
  /// <param name="cache">Saveables cache. Not supported yet: any non-null value is rejected.</param>
  /// <returns>
  /// A tuple of (serialized tensors keyed by trackable, feed additions, registered savers, object graph proto).
  /// The feed additions element is currently always <c>null</c>.
  /// </returns>
  /// <exception cref="NotImplementedException">Thrown when <paramref name="cache"/> is not null.</exception>
  public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
      serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
  {
      // Reject the unsupported cache path up front, before any serialization work is done.
      // (The previous code threw NotFiniteNumberException, which is an arithmetic exception and
      // misreports the actual problem.)
      if (cache is not null)
      {
          // TODO: deal with cache.
          throw new NotImplementedException("serialize_graph_view does not support a non-null cache yet.");
      }

      var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
      var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);
      var object_graph_proto = fill_object_graph_proto(trackable_data);
      var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
      var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

      // Without a cache there is nothing to feed; callers receive null for this element.
      Dictionary<Tensor, object> feed_additions = null;

      // Append the tensors of the py-state trackables to the ones already gathered.
      var pystate_tensors = get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
          cache, object_graph_proto);
      serialized_tensors = serialized_tensors.Concat(pystate_tensors).ToDictionary(x => x.Key, x => x.Value);

      CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
      return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
  }
  /// <summary>
  /// Traverses <paramref name="graph_view"/> breadth-first and builds one <c>TrackableData</c>
  /// entry per visited object, together with a lookup from each object to its node index.
  /// </summary>
  /// <param name="graph_view">View over the object graph to traverse.</param>
  /// <param name="object_map">Optional mapping passed to <c>CheckPointUtils.get_mapped_trackable</c> to choose the object to save.</param>
  /// <returns>The collected entries (in traversal order) and the object-to-node-index lookup.</returns>
  private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
  {
      var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

      // Readable name of every object, derived from its path from the root.
      var object_names = new Dictionary<Trackable, string>();
      foreach (var path_entry in node_paths)
      {
          object_names[path_entry.Key] = TrackableUtils.object_path_to_string(path_entry.Value);
      }

      // The node id of an object is its position in the traversal.
      var node_ids = new Dictionary<Trackable, int>();
      for (int position = 0; position < trackable_objects.Count; position++)
      {
          node_ids[trackable_objects[position]] = position;
      }

      var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);

      // Builds the child references (node id + local name) of a single object.
      pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> build_child_references(Trackable parent)
      {
          var references = new pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
          foreach (var child in graph_view.list_children(parent))
          {
              var reference = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
              reference.NodeId = node_ids[child.Refer];
              reference.LocalName = child.Name;
              references.Add(reference);
          }
          return references;
      }

      var collected = new List<TrackableData>();
      foreach (var current in trackable_objects)
      {
          var children_proto = build_child_references(current);

          // Objects without recorded slot variables get an empty list.
          if (!slot_variables.TryGetValue(current, out var slot_variable) || slot_variable is null)
          {
              slot_variable = new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>();
          }

          collected.Add(new TrackableData(
              current,
              node_ids[current],
              object_names[current],
              children_proto,
              slot_variable,
              CheckPointUtils.get_mapped_trackable(current, object_map)));
      }
      return (collected, node_ids);
  }
  /// <summary>
  /// Creates a <c>TrackableObjectGraph</c> with one node per entry of <paramref name="trackable_data"/>,
  /// each built from the entry's slot-variable and children protos.
  /// </summary>
  /// <param name="trackable_data">Entries in node order; entry <c>i</c> is expected to have <c>node_id == i</c> (checked with a debug assertion).</param>
  /// <returns>The populated object graph proto.</returns>
  private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
  {
      var graph = new TrackableObjectGraph();
      int expected_id = 0;
      foreach (var entry in trackable_data)
      {
          Debug.Assert(entry.node_id == expected_id);
          var node = new TrackableObjectGraph.Types.TrackableObject(entry.slot_variable_proto, entry.children_proto);
          graph.Nodes.Add(node);
          expected_id++;
      }
      return graph;
  }
  /// <summary>
  /// Builds the dictionary of tensors to checkpoint for the given trackables, handing the proto
  /// to the per-trackable helpers so they can update it.
  /// </summary>
  /// <param name="tensor_trackables">Entries whose tensors should be gathered.</param>
  /// <param name="node_ids">Object-to-node-index lookup; forwarded to the legacy-saveable helper.</param>
  /// <param name="call_with_mapped_captures">Forwarded unchanged to the per-trackable helpers.</param>
  /// <param name="cache">Currently not used by this method.</param>
  /// <param name="object_graph_proto">Object graph proto handed to the per-trackable helpers.</param>
  /// <returns>
  /// The gathered tensor dictionaries keyed by the trackable that owns them; <c>Trackable.None</c>
  /// is used as the key when a helper yields no trackable.
  /// </returns>
  private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
      bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
  {
      var result = new Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>();
      foreach (var entry in tensor_trackables)
      {
          // TODO: deal with cache.
          var legacy_name = SaveableCompat.get_saveable_name(entry.object_to_save) ?? "";

          // The direct path applies only when the object can serialize itself AND carries no legacy saveable name.
          bool serializes_itself = saveable_object_util.trackable_has_serialize_to_tensor(entry.object_to_save)
              && legacy_name.Length == 0;

          Trackable owner;
          IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensors;
          if (serializes_itself)
          {
              tensors = get_tensors_from_trackable(entry, call_with_mapped_captures, object_graph_proto);
              owner = entry.object_to_save;
          }
          else
          {
              (owner, tensors) = get_tensors_from_legacy_saveable(entry, node_ids, call_with_mapped_captures, object_graph_proto);
          }

          // Fall back to the sentinel key when there is no owning trackable.
          result[owner ?? Trackable.None] = tensors;
      }
      return result;
  }
  /// <summary>
  /// Serializes <c>trackable_data.object_to_save</c> through its own <c>serialize_to_tensors</c>
  /// and re-keys the result by checkpoint key. Names of contained <c>SaveSpec</c> values are
  /// prefixed with the escaped local name, and a <c>SerializedTensor</c> attribute is appended to
  /// the object's node in <paramref name="object_graph_proto"/> for every entry.
  /// </summary>
  /// <param name="trackable_data">Entry describing the object to serialize.</param>
  /// <param name="call_with_mapped_captures">Must be <c>false</c>; the <c>true</c> path is not implemented.</param>
  /// <param name="object_graph_proto">Proto to record the serialized tensors in; skipped when null.</param>
  /// <returns>The tensor dictionaries keyed by checkpoint key.</returns>
  /// <exception cref="NotImplementedException">Thrown when <paramref name="call_with_mapped_captures"/> is <c>true</c>.</exception>
  private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
  {
      var target = trackable_data.object_to_save;

      // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
      if (call_with_mapped_captures)
      {
          throw new NotImplementedException();
      }
      IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> raw_tensors = target.serialize_to_tensors();

      var keyed_tensors = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>();
      foreach (var raw_entry in raw_tensors)
      {
          var local_name = TrackableUtils.escape_local_name(raw_entry.Key);
          var tensors = raw_entry.Value;
          var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);
          keyed_tensors[checkpoint_key] = tensors;

          // Prefix the name of every SaveSpec with the escaped local name.
          foreach (var slot in tensors)
          {
              var candidate = slot.Value;
              if (candidate.IsTypeOrDeriveFrom<SaveSpec>())
              {
                  var spec = candidate.AsT1;
                  spec.name = local_name + spec.name;
              }
          }

          if (object_graph_proto is null)
          {
              continue;
          }
          var attribute = new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor();
          attribute.Name = local_name;
          attribute.CheckpointKey = checkpoint_key;
          attribute.FullName = CheckPointUtils.get_full_name(target);
          object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(attribute);
      }
      return keyed_tensors;
  }
  /// <summary>
  /// Gets tensors to serialize from a Trackable with legacy SaveableObjects, by generating its
  /// saveable objects through <c>SaveUtilV1</c> and wrapping them in a
  /// <c>SaveableCompatibilityConverter</c>.
  /// </summary>
  /// <param name="trackable_data">Entry describing the trackable and the object to save for it.</param>
  /// <param name="node_ids">Object-to-node-index lookup, forwarded to <c>SaveUtilV1.generate_saveable_objects</c>.</param>
  /// <param name="call_with_mapped_captures">Forwarded to <c>SaveUtilV1.generate_saveable_objects</c>.</param>
  /// <param name="object_graph_proto">Object graph proto, forwarded to <c>SaveUtilV1.generate_saveable_objects</c>.</param>
  /// <returns>The converter wrapping the object to save, and the tensors it serializes to.</returns>
  private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
      bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
  {
      // Single-entry lookups covering just this trackable.
      var names = new Dictionary<Trackable, string>
      {
          [trackable_data.trackable] = trackable_data.object_name
      };
      var mapping = new Dictionary<Trackable, Trackable>
      {
          [trackable_data.trackable] = trackable_data.object_to_save
      };

      var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(names, mapping);
      var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(
          checkpoint_factory_map,
          object_graph_proto,
          node_ids,
          mapping,
          call_with_mapped_captures,
          saveables_cache: null);

      var converter = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects);
      var tensors = converter.serialize_to_tensors();
      return (converter, tensors);
  }
  /// <summary>
  /// Groups the objects handled by registered savers: for every saver name, maps each object's
  /// name to the object that should be saved.
  /// </summary>
  /// <param name="registered_trackables">Entries grouped by registered saver name.</param>
  /// <param name="object_graph_proto">Object graph proto; each entry's node is looked up by its <c>node_id</c>.</param>
  /// <returns>For each saver name, a dictionary from object name to the object to save.</returns>
  private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
  {
      Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
      foreach (var pair in registered_trackables)
      {
          foreach (var td in pair.Value)
          {
              // Create the per-saver dictionary the first time a saver name is seen, then always
              // record the object. (The previous condition was inverted: it indexed a missing key
              // on first use and wiped the existing entries on every later use.)
              if (!registered_savers.TryGetValue(pair.Key, out var savers_for_name))
              {
                  savers_for_name = new Dictionary<string, Trackable>();
                  registered_savers[pair.Key] = savers_for_name;
              }
              savers_for_name[td.object_name] = td.object_to_save;

              var object_proto = object_graph_proto.Nodes[td.node_id];
              // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
          }
      }
      return registered_savers;
  }
  /// <summary>
  /// Splits the entries into tensor trackables, py-state trackables and registered trackables.
  /// At present every entry is classified as a tensor trackable; the other two groups are
  /// returned empty.
  /// </summary>
  /// <param name="trackable_data">Entries to classify.</param>
  /// <returns>The three groups, in the order (tensor, py-state, registered).</returns>
  private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
  {
      var tensor_group = new List<TrackableData>();
      // `PyState` processing is skipped for lack of an API, so this list is only a placeholder.
      var py_state_group = new List<TrackableData>();
      var registered_group = new Dictionary<string, IList<TrackableData>>();

      foreach (var entry in trackable_data)
      {
          // TODO: deal with registration.
          tensor_group.Add(entry);
      }

      return (tensor_group, py_state_group, registered_group);
  }
  236. }
  237. }