You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SaveUtil.cs 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow.Train;
  7. using Tensorflow.Training;
  8. using pbc = global::Google.Protobuf.Collections;
  9. namespace Tensorflow.Checkpoint
  10. {
  /// <summary>
  /// Per-node data collected for one <see cref="Trackable"/> while an object graph is being serialized.
  /// </summary>
  /// <param name="trackable">A trackable in the root Trackable object graph.</param>
  /// <param name="node_id">The index at which the Trackable appears in TrackableObjectGraph.nodes.</param>
  /// <param name="object_name">The BFS-generated path from the root object; used to generate readable checkpoint keys.</param>
  /// <param name="children_proto">A list of ObjectReference for each child connected to this Trackable.</param>
  /// <param name="slot_variable_proto">A list of SlotVariableReference to save to the object (only valid for Optimizer objects).</param>
  /// <param name="object_to_save">
  /// The object to save to checkpoint. Usually this is the same as <paramref name="trackable"/>,
  /// but can differ when the caller wants to specify a different object to save.
  /// For example, when saving checkpoints asynchronously, variables are copied to the CPU
  /// and <c>object_to_save</c> is set as the copied variable.
  /// </param>
  internal record class TrackableData(
      Trackable trackable,
      int node_id,
      string object_name,
      pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto,
      pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto,
      Trackable object_to_save
  );
  28. public static class SaveUtil
  29. {
  /// <summary>
  /// Collects everything needed to checkpoint the objects reachable through <paramref name="graph_view"/>:
  /// the tensors to write, the registered savers, and the object graph proto describing the nodes.
  /// </summary>
  /// <param name="graph_view">The object graph view that is traversed to find the trackables.</param>
  /// <param name="object_map">Optional mapping forwarded to the gathering step, where it selects the object to save for each trackable.</param>
  /// <param name="call_with_mapped_captures">Forwarded unchanged to the tensor-collection helpers.</param>
  /// <param name="cache">Must be <c>null</c>; the cached path is not implemented yet.</param>
  /// <returns>
  /// A tuple of (serialized tensors keyed by trackable, feed additions, registered savers, object graph proto).
  /// The feed additions element is always <c>null</c> because it is only produced by the unimplemented cached path.
  /// </returns>
  /// <exception cref="NotImplementedException">Thrown when <paramref name="cache"/> is not <c>null</c>.</exception>
  public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
      serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
  {
      var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
      var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);
      var object_graph_proto = fill_object_graph_proto(trackable_data);

      var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
      var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

      if (cache is not null)
      {
          // TODO: deal with cache.
          throw new NotImplementedException("Serializing a graph view with a cache is not supported yet.");
      }

      // Without a cache the py-state trackables go through the same tensor path.
      // Merge in place: indexer assignment overwrites an existing key instead of
      // throwing on a duplicate, which a Concat + ToDictionary merge would do.
      var pystate_tensors = get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
      foreach (var pair in pystate_tensors)
      {
          serialized_tensors[pair.Key] = pair.Value;
      }

      // Feed additions only exist on the cached path, so there are none to report here.
      Dictionary<Tensor, object> feed_additions = null;

      CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
      return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
  }
  /// <summary>
  /// Traverses the object graph breadth-first and builds one <see cref="TrackableData"/> per visited trackable.
  /// </summary>
  /// <param name="graph_view">The object graph view to traverse and to query for children.</param>
  /// <param name="object_map">Optional mapping passed to <c>CheckPointUtils.get_mapped_trackable</c> to pick the object to save.</param>
  /// <returns>The collected data in traversal order, and the map from trackable to its node id (its traversal index).</returns>
  private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map)
  {
      var (visited, paths_by_node) = graph_view.breadth_first_traversal();

      // Readable name of every node, derived from its path from the root.
      var object_names = new Dictionary<Trackable, string>();
      foreach (var entry in paths_by_node)
      {
          object_names[entry.Key] = TrackableUtils.object_path_to_string(entry.Value);
      }

      // A node id is simply the position of the trackable in the traversal order.
      var node_ids = new Dictionary<Trackable, int>();
      int next_id = 0;
      foreach (var visited_node in visited)
      {
          node_ids[visited_node] = next_id;
          next_id++;
      }

      var slot_variables = CheckPointUtils.serialize_slot_variables(visited, node_ids, object_names);

      var collected = new List<TrackableData>();
      foreach (var current in visited)
      {
          var child_refs = new pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
          foreach (var child in graph_view.list_children(current))
          {
              var reference = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
              reference.NodeId = node_ids[child.Refer];
              reference.LocalName = child.Name;
              child_refs.Add(reference);
          }

          // Trackables without slot variables get an empty list rather than null.
          slot_variables.TryGetValue(current, out var slot_refs);
          slot_refs ??= new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>();

          collected.Add(new TrackableData(
              current,
              node_ids[current],
              object_names[current],
              child_refs,
              slot_refs,
              CheckPointUtils.get_mapped_trackable(current, object_map)));
      }

      return (collected, node_ids);
  }
  /// <summary>
  /// Builds a <see cref="TrackableObjectGraph"/> with one node per entry of <paramref name="trackable_data"/>,
  /// added in list order.
  /// </summary>
  /// <param name="trackable_data">The trackable data; each entry's <c>node_id</c> is asserted (debug builds) to equal its list position.</param>
  /// <returns>The object graph proto whose nodes carry each entry's slot variables and children.</returns>
  private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data)
  {
      var graph = new TrackableObjectGraph();
      int expected_id = 0;
      foreach (var entry in trackable_data)
      {
          // Nodes are appended in order, so the node id must match the position.
          Debug.Assert(entry.node_id == expected_id);
          var node = new TrackableObjectGraph.Types.TrackableObject(entry.slot_variable_proto, entry.children_proto);
          graph.Nodes.Add(node);
          expected_id++;
      }
      return graph;
  }
  /// <summary>
  /// Creates dictionary of tensors to checkpoint, and updates the proto.
  /// </summary>
  /// <param name="tensor_trackables">The trackables whose tensors are collected.</param>
  /// <param name="node_ids">Map from trackable to node id; only used by the legacy saveable path.</param>
  /// <param name="call_with_mapped_captures">Forwarded to the per-trackable helpers.</param>
  /// <param name="cache">Currently unused (see the TODO in the body).</param>
  /// <param name="object_graph_proto">The proto that the per-trackable helpers update.</param>
  /// <returns>
  /// The tensor dictionary of each trackable, keyed by the trackable that produced it
  /// (or by <c>Trackable.None</c> when the helper returned no trackable).
  /// </returns>
  private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
      bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
  {
      var result = new Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>();
      foreach (var data in tensor_trackables)
      {
          // TODO: deal with cache.
          var saveable_name = SaveableCompat.get_saveable_name(data.object_to_save) ?? "";

          Trackable key;
          IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensors;
          if (saveable_object_util.trackable_has_serialize_to_tensor(data.object_to_save) && saveable_name.Length == 0)
          {
              // The object serializes itself and has no legacy saveable name.
              tensors = get_tensors_from_trackable(data, call_with_mapped_captures, object_graph_proto);
              key = data.object_to_save;
          }
          else
          {
              (key, tensors) = get_tensors_from_legacy_saveable(data, node_ids, call_with_mapped_captures, object_graph_proto);
          }

          result[key ?? Trackable.None] = tensors;
      }
      return result;
  }
  /// <summary>
  /// Serializes the object to save of <paramref name="trackable_data"/> and re-keys the resulting tensors by checkpoint key.
  /// </summary>
  /// <param name="trackable_data">The trackable data whose <c>object_to_save</c> is serialized.</param>
  /// <param name="call_with_mapped_captures">Must be <c>false</c>; the mapped-captures path is not implemented.</param>
  /// <param name="object_graph_proto">When not <c>null</c>, a serialized-tensor attribute is appended to this trackable's node for every tensor.</param>
  /// <returns>The serialized tensors keyed by checkpoint key.</returns>
  /// <exception cref="NotImplementedException">
  /// Thrown when <paramref name="call_with_mapped_captures"/> is <c>true</c>, or when a serialized value is a <c>SaveSpec</c>.
  /// </exception>
  private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
  {
      // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
      if (call_with_mapped_captures)
      {
          throw new NotImplementedException();
      }

      var target = trackable_data.object_to_save;
      IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> raw_tensors = target.serialize_to_tensors();

      var keyed_tensors = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
      foreach (var entry in raw_tensors)
      {
          var local_name = TrackableUtils.escape_local_name(entry.Key);
          var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name);
          var value = entry.Value;
          keyed_tensors[checkpoint_key] = value;

          // TODO: deal with the type `SaveSpec` (currently it will never be it).
          if (value.GetValueA() is SaveSpec)
          {
              throw new NotImplementedException();
          }

          if (object_graph_proto is null)
          {
              continue;
          }

          // Record the tensor on the node that belongs to this trackable.
          var attributes = object_graph_proto.Nodes[trackable_data.node_id].Attributes;
          attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
          {
              Name = local_name,
              CheckpointKey = checkpoint_key,
              FullName = CheckPointUtils.get_full_name(target)
          });
      }
      return keyed_tensors;
  }
  /// <summary>
  /// Gets tensors to serialize from a Trackable with legacy SaveableObjects.
  /// </summary>
  /// <param name="trackable_data">The trackable data to convert; its trackable is mapped to its object to save.</param>
  /// <param name="node_ids">Map from trackable to node id, forwarded to the saveable-object generation.</param>
  /// <param name="call_with_mapped_captures">Forwarded to the saveable-object generation.</param>
  /// <param name="object_graph_proto">The object graph proto, forwarded to the saveable-object generation.</param>
  /// <returns>The compatibility converter wrapping the object to save, together with the tensors it serializes to.</returns>
  private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
      bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
  {
      var source = trackable_data.trackable;
      var target = trackable_data.object_to_save;

      // Single-entry maps: only this trackable takes part in the legacy conversion.
      var object_names = new Dictionary<Trackable, string> { [source] = trackable_data.object_name };
      var object_map = new Dictionary<Trackable, Trackable> { [source] = target };

      var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map);
      var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(
          checkpoint_factory_map,
          object_graph_proto,
          node_ids,
          object_map,
          call_with_mapped_captures,
          saveables_cache: null);

      var converter = new SaveableCompatibilityConverter(target, named_saveable_objects);
      var tensors = converter.serialize_to_tensors();
      return (converter, tensors);
  }
  /// <summary>
  /// Groups the objects to save of the registered trackables by saver name and object name.
  /// </summary>
  /// <param name="registered_trackables">Trackable data grouped by registered saver name.</param>
  /// <param name="object_graph_proto">The object graph proto; not used yet (see the TODO in the body).</param>
  /// <returns>
  /// For every saver name that has at least one trackable, a dictionary from object name to the object to save.
  /// </returns>
  private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
  {
      Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
      foreach (var pair in registered_trackables)
      {
          foreach (var td in pair.Value)
          {
              // Create the per-saver dictionary on first use, then always record the object.
              // (The previous check was inverted: it indexed a missing key and reset an existing one.)
              if (!registered_savers.TryGetValue(pair.Key, out var saver_objects))
              {
                  saver_objects = new Dictionary<string, Trackable>();
                  registered_savers[pair.Key] = saver_objects;
              }
              saver_objects[td.object_name] = td.object_to_save;

              // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
              // Once available, the node at `object_graph_proto.Nodes[td.node_id]` should be updated here.
          }
      }
      return registered_savers;
  }
  /// <summary>
  /// Splits the trackable data into tensor, py-state and registered groups.
  /// Currently every entry is placed in the tensor group; the other two groups are returned empty.
  /// </summary>
  /// <param name="trackable_data">The trackable data to split.</param>
  /// <returns>A tuple of (tensor trackables, py-state trackables, registered trackables by saver name).</returns>
  private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data)
  {
      var tensor_group = new List<TrackableData>();
      // Skip the process of `PyState` for the lack of API. This is only a placeholder.
      var py_state_group = new List<TrackableData>();
      var registered_group = new Dictionary<string, IList<TrackableData>>();

      foreach (var item in trackable_data)
      {
          // TODO: deal with registration.
          tensor_group.Add(item);
      }

      return (tensor_group, py_state_group, registered_group);
  }
  233. }
  234. }