You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Trackable.cs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections.Generic;
  15. using System.Linq;
  16. using Tensorflow.ModelSaving;
  17. using static Tensorflow.Binding;
  18. namespace Tensorflow.Train
  19. {
  /// <summary>
  /// Base class for objects that take part in checkpoint dependency tracking.
  /// Several members are partial ports of the Python implementation (see the TODOs).
  /// </summary>
  public abstract class Trackable
  {
      /// <summary>
      /// Corresponding to tensorflow/python/trackable/constants.py
      /// </summary>
      public static class Constants
      {
          public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH";
          public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE";
          public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
      }

      protected int _self_update_uid;
      protected IDictionary<string, Trackable> _unconditional_dependency_names;
      protected IList<TrackableReference> _unconditional_checkpoint_dependencies;
      protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
          new Dictionary<string, ResourceVariable>();

      // Shared sentinel instance; it is never reassigned, so mark it readonly.
      private static readonly Trackable _none = new Function();

      /// <summary>
      /// Sentinel used in place of <c>null</c>, because C# <c>Dictionary</c> keys cannot be null.
      /// Any object deriving from <see cref="Trackable"/> can serve as this sentinel.
      /// Intended for internal use only.
      /// </summary>
      public static Trackable None => _none;

      /// <summary>Identifier of this object's kind; subclasses may override it.</summary>
      public virtual string ObjectIdentifier => "_generic_user_object";

      public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; }

      /// <summary>Tracked dependencies; <c>null</c> until <see cref="_maybe_initialize_trackable"/> has run.</summary>
      public IList<TrackableReference> UnconditionalCheckpointDependencies => _unconditional_checkpoint_dependencies;

      /// <summary>Name-to-dependency map; <c>null</c> until <see cref="_maybe_initialize_trackable"/> has run.</summary>
      public IDictionary<string, Trackable> UnconditionalDependencyNames => _unconditional_dependency_names;

      /// <summary>Currently identical to <see cref="UnconditionalCheckpointDependencies"/>.</summary>
      public IList<TrackableReference> CheckpointDependencies => UnconditionalCheckpointDependencies;

      /// <summary>
      /// Restore-on-create for a variable be saved with this `Checkpointable`.
      /// </summary>
      /// <param name="args">Variable arguments; <c>args.Getter</c> is invoked to create the variable.</param>
      /// <returns>The variable produced by <c>args.Getter</c>.</returns>
      protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args)
      {
          tf_with(ops.init_scope(), delegate
          {
              // TODO: restore-on-create is not ported yet. The eager-mode check is still
              // evaluated (kept in case it has side effects), but no checkpoint initializer
              // is produced.
              _ = tf.Context.executing_eagerly();
          });

          var new_variable = args.Getter(args);

          // If we set an initializer and the variable processed it, tracking will not
          // assign again. It will add this variable to our dependencies, and if there
          // is a non-trivial restoration queued, it will handle that. This also
          // handles slot variables.
          if (!args.Overwrite || new_variable is RefVariable)
          {
              return _track_checkpointable(new_variable, name: args.Name, overwrite: args.Overwrite);
          }
          return new_variable;
      }

      /// <summary>
      /// Pop and load any deferred checkpoint restores into `trackable`.
      /// Currently this only ensures dependency tracking is initialized; loading is a TODO.
      /// </summary>
      /// <param name="name">Name of the dependency (unused so far).</param>
      /// <param name="trackable">The dependency to restore into (unused so far).</param>
      protected void _handle_deferred_dependencies(string name, IVariableV1 trackable)
      {
          _maybe_initialize_trackable();
          // TODO
      }

      /// <summary>
      /// Stub: returns <paramref name="checkpointable"/> unchanged without registering it.
      /// </summary>
      protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string name, bool overwrite = false)
      {
          return checkpointable;
      }

      /// <summary>
      /// Initialize dependency management, if it has not been initialized already.
      /// Safe to call repeatedly: existing dependencies and the update uid are preserved.
      /// </summary>
      public void _maybe_initialize_trackable()
      {
          // Already initialized. Resetting here would discard every tracked dependency
          // each time a caller (e.g. _trackable_children) asks for them.
          if (_unconditional_checkpoint_dependencies is not null)
          {
              return;
          }
          _self_update_uid = -1;
          _unconditional_checkpoint_dependencies = new List<TrackableReference>();
          _unconditional_dependency_names = new Dictionary<string, Trackable>();
      }

      // TODO: cache
      /// <summary>
      /// Returns this object's tracked dependencies keyed by name.
      /// </summary>
      /// <param name="save_type">Kind of save in progress (not consulted by this base implementation).</param>
      /// <param name="cache">Optional cache (not used yet).</param>
      /// <returns>A new dictionary mapping dependency name to dependency.</returns>
      /// <exception cref="ArgumentException">Two dependencies share the same name.</exception>
      public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
      {
          _maybe_initialize_trackable();
          return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
      }

      /// <summary>
      /// Returns <paramref name="obj"/> as a <see cref="Trackable"/>.
      /// </summary>
      /// <exception cref="NotImplementedException">
      /// <paramref name="obj"/> is not already a <see cref="Trackable"/> (including <c>null</c>).
      /// </exception>
      public static Trackable convert_to_trackable(object obj, object? parent = null)
      {
          if (obj is Trackable trackable)
          {
              return trackable;
          }
          // TODO: wrapping of non-Trackable values is not ported yet.
          throw new NotImplementedException();
      }

      /// <summary>Dependencies that must be deserialized before this object; none by default.</summary>
      public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children)
      {
          return new Dictionary<string, Trackable>();
      }

      /// <summary>
      /// Maps resources owned by this object to their exported copies; empty by default.
      /// </summary>
      public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(
          SaveOptions? save_options)
      {
          return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>());
      }

      /// <summary>
      /// Merges this object's resource mappings (from <see cref="map_resources"/>) into the
      /// caller-supplied maps and returns the tensors that were mapped.
      /// </summary>
      /// <param name="object_map">Map to update in place; skipped when <c>null</c>.</param>
      /// <param name="tensor_map">Map to update in place; skipped when <c>null</c>.</param>
      /// <param name="options">Save options forwarded to <see cref="map_resources"/>.</param>
      /// <returns>The keys of this object's tensor map.</returns>
      public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null,
          IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null)
      {
          var (self_object_map, self_tensor_map) = map_resources(options);

          // Both maps default to null; guard instead of dereferencing them unconditionally.
          // Indexer assignment mirrors Python's dict.update (existing keys are overwritten).
          if (object_map is not null)
          {
              foreach (var pair in self_object_map)
              {
                  object_map[pair.Key] = pair.Value;
              }
          }
          if (tensor_map is not null)
          {
              foreach (var pair in self_tensor_map)
              {
                  tensor_map[pair.Key] = pair.Value;
              }
          }
          return self_tensor_map.Keys.ToList();
      }

      /// <summary>Returns the saveable factories registered on this object (the live dictionary).</summary>
      public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint()
      {
          return _self_saveable_object_factories;
      }

      /// <summary>
      /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors`
      /// if you are defining a custom resource or variable with custom ops.
      /// Otherwise, please store the state of your trackable in `tf.Variable` objects
      /// and add them to Trackable object hierarchy using `setattr` (for subclasses
      /// of `AutoTrackable`) or overriding the `_trackable_children` method.
      /// </summary>
      /// <returns>Not implemented in the base class.</returns>
      /// <exception cref="NotImplementedException">Always thrown by the base implementation.</exception>
      public virtual IDictionary<string, object> serialize_to_tensors()
      {
          throw new NotImplementedException();
      }
  }
  /// <summary>
  /// Immutable pairing of a dependency name with the tracked object it refers to.
  /// </summary>
  /// <param name="Name">Name under which the dependency is tracked.</param>
  /// <param name="Refer">The tracked object.</param>
  public record TrackableReference(string Name, Trackable Refer);
  168. }