You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serialized_attributes.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Keras.Engine;
  6. using Tensorflow.Keras.Layers.Rnn;
  7. using Tensorflow.Keras.Metrics;
  8. using Tensorflow.Train;
  9. namespace Tensorflow.Keras.Saving.SavedModel
  10. {
  // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#,
  // using the name "Attributes" may be quite confusing.
  /// <summary>
  /// Class that tracks and validates all serialization attributes.
  /// </summary>
  /// <remarks>
  /// An instance holds two sets of accepted names (checkpointable objects and functions) and two
  /// dictionaries with the values stored so far for those names. Only names in the accepted sets are
  /// stored by <c>set_and_validate_objects</c> / <c>set_and_validate_functions</c>.
  /// </remarks>
  public abstract class SerializedAttributes : ISerializedAttributes
  {
      /// <summary>Checkpointable objects stored so far, keyed by attribute name.</summary>
      protected IDictionary<string, Trackable?> _object_dict;

      /// <summary>Functions stored so far, keyed by attribute name.</summary>
      protected IDictionary<string, Trackable?> _function_dict;

      /// <summary>Trackable that is exposed under <c>Constants.KERAS_ATTR</c> by <c>ObjectsToSerialize</c>.</summary>
      protected AutoTrackable _keras_trackable;

      /// <summary>Names of the functions this instance accepts.</summary>
      internal HashSet<string> _all_functions;

      /// <summary>Names of the checkpointable objects this instance accepts.</summary>
      internal HashSet<string> _all_checkpointable_objects;

      /// <summary>
      /// Creates an instance that accepts no object names and no function names.
      /// </summary>
      private SerializedAttributes()
          : this(Enumerable.Empty<string>(), Enumerable.Empty<string>())
      {
      }

      /// <summary>
      /// Creates an instance that accepts the given object and function names.
      /// </summary>
      /// <param name="checkpointable_objects">Names copied into the accepted-object set.</param>
      /// <param name="functions">Names copied into the accepted-function set.</param>
      protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
      {
          _object_dict = new Dictionary<string, Trackable?>();
          _function_dict = new Dictionary<string, Trackable?>();
          _keras_trackable = new AutoTrackable();
          _all_checkpointable_objects = new HashSet<string>(checkpointable_objects);
          _all_functions = new HashSet<string>(functions);
      }

      /// <summary>
      /// Creates an instance from a tuple of (object names, function names).
      /// </summary>
      /// <param name="objects_and_functions">
      /// Item1 is copied into the accepted-object set, Item2 into the accepted-function set.
      /// </param>
      protected SerializedAttributes((IEnumerable<string>, IEnumerable<string>) objects_and_functions)
          : this(objects_and_functions.Item1, objects_and_functions.Item2)
      {
      }

      /// <summary>
      /// Stored functions whose value is not null. A new dictionary is built on every access.
      /// </summary>
      public IDictionary<string, Trackable> Functions =>
          _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

      /// <summary>
      /// Stored checkpointable objects whose value is not null. A new dictionary is built on every access.
      /// </summary>
      public IDictionary<string, Trackable> CheckpointableObjects =>
          _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

      /// <summary>
      /// Returns functions to attach to the root object during serialization.
      /// </summary>
      public IDictionary<string, Trackable> FunctionsToSerialize
      {
          get
          {
              // TODO: deal with `LayerCall`.
              return Functions
                  .Where(x => _all_functions.Contains(x.Key))
                  .ToDictionary(x => x.Key, x => x.Value);
          }
      }

      /// <summary>
      /// Returns objects to attach to the root object during serialization.
      /// </summary>
      /// <remarks>
      /// The result always contains an extra entry keyed by <c>Constants.KERAS_ATTR</c> that maps to
      /// the internal keras trackable.
      /// </remarks>
      public IDictionary<string, Trackable> ObjectsToSerialize
      {
          get
          {
              var objects = CheckpointableObjects
                  .Where(x => _all_checkpointable_objects.Contains(x.Key))
                  .ToDictionary(x => x.Key, x => x.Value);
              objects[Constants.KERAS_ATTR] = _keras_trackable;
              return objects;
          }
      }

      /// <summary>
      /// Saves function dictionary, and validates dictionary values.
      /// </summary>
      /// <param name="function_dict">Candidate functions keyed by name; only accepted names are read.</param>
      /// <returns>The stored non-null functions (same as <c>Functions</c>).</returns>
      /// <exception cref="ValueError">
      /// An accepted key maps to a non-null value that is not a <c>Function</c>.
      /// </exception>
      public IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict)
      {
          foreach (var key in _all_functions)
          {
              if (!function_dict.TryGetValue(key, out var fn))
              {
                  // TODO(Rinne): high priority - complete the implementation. The intended behavior is
                  // to throw a ValueError saying the function is missing from the serialized function dict.
                  continue;
              }

              // TODO: deal with type `LayerCall`.
              if (fn is not null && fn is not Function)
              {
                  throw new ValueError($"Function dictionary contained a non-function object: {fn} (for key {key}).");
              }

              _function_dict[key] = fn;
              AssignKerasTrackableProperty(key, fn);
          }

          return Functions;
      }

      /// <summary>
      /// Saves objects to a dictionary, and validates the values.
      /// </summary>
      /// <param name="object_dict">Candidate objects keyed by name; only accepted names are read.</param>
      /// <returns>The stored non-null objects (same as <c>CheckpointableObjects</c>).</returns>
      public IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict)
      {
          foreach (var key in _all_checkpointable_objects)
          {
              if (!object_dict.TryGetValue(key, out var value))
              {
                  // TODO(Rinne): high priority - add the implementation. The intended behavior is
                  // to throw a ValueError saying the object is missing from the serialized object dict.
                  continue;
              }

              _object_dict[key] = value;
              AssignKerasTrackableProperty(key, value);
          }

          return CheckpointableObjects;
      }

      /// <summary>
      /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python).
      /// </summary>
      /// <param name="obj">Object whose runtime type selects the attribute class.</param>
      /// <returns>
      /// A <c>ModelAttributes</c>, <c>MetricAttributes</c>, <c>RNNAttributes</c> or
      /// <c>LayerAttributes</c> instance, checked in that order.
      /// </returns>
      /// <exception cref="TypeError">The object is none of the supported types (including null).</exception>
      public static SerializedAttributes Create(Trackable obj)
      {
          // The specific types are tested before the generic Layer check.
          // NOTE(review): presumably because they derive from Layer - confirm before reordering.
          if (obj is Model)
          {
              return new ModelAttributes();
          }

          if (obj is Metric)
          {
              return new MetricAttributes();
          }

          if (obj is RNN)
          {
              return new RNNAttributes();
          }

          if (obj is Layer)
          {
              return new LayerAttributes();
          }

          // `?.` keeps a null argument from raising NullReferenceException while the message is built.
          throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj?.GetType()}");
      }

      /// <summary>
      /// Returns the given name collections as a tuple, substituting an empty list for a null argument.
      /// </summary>
      protected virtual (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
      {
          return (checkpointable_objects ?? new List<string>(), functions ?? new List<string>());
      }

      /// <summary>
      /// Sets the first public property of the keras trackable whose name equals
      /// <paramref name="name"/>; does nothing when no such property exists.
      /// </summary>
      // Warning: this implementation should be considered again.
      private void AssignKerasTrackableProperty(string name, Trackable? value)
      {
          var property = _keras_trackable.GetType().GetProperties().FirstOrDefault(p => p.Name == name);
          property?.SetValue(_keras_trackable, value);
      }
  }
  // Note that the current implementation still has some potential risks.
  // The TensorFlow Python implementation describes this class as
  // "Common endpoints shared by all models loadable by Keras".
  // However, currently it's just a normal class.
  /// <summary>
  /// Serialized attributes that always include the common object names
  /// ("variables", "trainable_variables", "regularization_losses") and the common function names
  /// ("__call__", "call_and_return_all_conditional_losses", "_default_save_signature").
  /// </summary>
  public class CommonEndPoints : SerializedAttributes
  {
      // Shared by both constructors so the two name lists are spelled out only once.
      private static readonly string[] CommonObjectNames =
      {
          "variables", "trainable_variables", "regularization_losses"
      };

      private static readonly string[] CommonFunctionNames =
      {
          "__call__", "call_and_return_all_conditional_losses", "_default_save_signature"
      };

      /// <summary>
      /// Accepts the caller-supplied names followed by the common names.
      /// </summary>
      /// <param name="checkpointable_objects">Additional checkpointable object names.</param>
      /// <param name="functions">Additional function names.</param>
      public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
          : base(checkpointable_objects.Concat(CommonObjectNames), functions.Concat(CommonFunctionNames))
      {
      }

      /// <summary>
      /// Accepts only the common object and function names.
      /// </summary>
      public CommonEndPoints()
          : base(CommonObjectNames, CommonFunctionNames)
      {
      }
  }
  /// <summary>
  /// Common endpoints plus the layer-specific object names "non_trainable_variables" and "layers".
  /// </summary>
  /// <remarks>
  /// TODO: an earlier, now disabled, version of both constructors also registered the object names
  /// "metrics", "layer_regularization_losses" and "layer_metrics".
  /// NOTE(review): the two-argument constructor registers the function names
  /// "call_and_return_conditional_losses" and "activity_regularizer_fn", while the parameterless
  /// constructor registers no layer-specific functions. Confirm whether this difference is intended.
  /// </remarks>
  public class LayerAttributes : CommonEndPoints
  {
      private static readonly string[] LayerObjectNames = { "non_trainable_variables", "layers" };

      private static readonly string[] LayerFunctionNames =
      {
          "call_and_return_conditional_losses", "activity_regularizer_fn"
      };

      /// <summary>
      /// Accepts the caller-supplied names followed by the layer object and function names.
      /// </summary>
      /// <param name="checkpointable_objects">Additional checkpointable object names.</param>
      /// <param name="functions">Additional function names.</param>
      public LayerAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
          : base(checkpointable_objects.Concat(LayerObjectNames), functions.Concat(LayerFunctionNames))
      {
      }

      /// <summary>
      /// Accepts the layer object names and no layer-specific function names.
      /// </summary>
      public LayerAttributes()
          : base(LayerObjectNames, Array.Empty<string>())
      {
      }
  }
  /// <summary>
  /// Attributes used for models; adds no names of its own on top of <c>LayerAttributes</c>.
  /// </summary>
  public class ModelAttributes : LayerAttributes
  {
      /// <summary>
      /// Forwards the supplied names unchanged to the <c>LayerAttributes</c> constructor.
      /// </summary>
      /// <param name="checkpointable_objects">Additional checkpointable object names.</param>
      /// <param name="functions">Additional function names.</param>
      public ModelAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
          : base(checkpointable_objects, functions)
      {
      }

      /// <summary>
      /// Uses the parameterless <c>LayerAttributes</c> constructor (invoked implicitly).
      /// </summary>
      public ModelAttributes()
      {
      }
  }
  /// <summary>
  /// Attributes used for metrics: the only built-in name is the checkpointable object "variables".
  /// </summary>
  public class MetricAttributes : SerializedAttributes
  {
      private static readonly string[] MetricObjectNames = { "variables" };

      /// <summary>
      /// Accepts the caller-supplied object names followed by "variables"; function names are
      /// passed through unchanged.
      /// </summary>
      /// <param name="checkpointable_objects">Additional checkpointable object names.</param>
      /// <param name="functions">Function names to accept.</param>
      public MetricAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
          : base(checkpointable_objects.Concat(MetricObjectNames), functions)
      {
      }

      /// <summary>
      /// Accepts the object name "variables" and no function names.
      /// </summary>
      public MetricAttributes()
          : base(MetricObjectNames, Array.Empty<string>())
      {
      }
  }
  /// <summary>
  /// Attributes used for RNN layers: everything from <c>LayerAttributes</c> plus "states".
  /// </summary>
  /// <remarks>
  /// "states" is registered as a checkpointable object, not as a function. The Keras Python
  /// reference declares it under <c>checkpointable_objects</c>. Registered as a function name,
  /// <c>set_and_validate_functions</c> would reject any non-<c>Function</c> value with a
  /// <c>ValueError</c>, and <c>set_and_validate_objects</c> would never store it.
  /// </remarks>
  public class RNNAttributes : LayerAttributes
  {
      /// <summary>
      /// Accepts the caller-supplied object names followed by "states"; function names are
      /// passed through unchanged to <c>LayerAttributes</c>.
      /// </summary>
      /// <param name="checkpointable_objects">Additional checkpointable object names.</param>
      /// <param name="functions">Additional function names.</param>
      public RNNAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
          : base(checkpointable_objects.Concat(new string[] { "states" }), functions)
      {
      }

      /// <summary>
      /// Accepts "states" as the only RNN-specific object name and no RNN-specific function names.
      /// </summary>
      public RNNAttributes()
          : base(new string[] { "states" }, new string[] { })
      {
      }
  }
  248. }