using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Metrics;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel
{
    // TODO: revise the name of these "Attributes" classes. Since attributes are a significant
    // language feature of C#, using the name "Attributes" here may be quite confusing.

    /// <summary>
    /// Class that tracks and validates all serialization attributes: the set of checkpointable
    /// object names and function names a tracker accepts, and the values recorded for them.
    /// </summary>
    public abstract class SerializedAttributes: ISerializedAttributes
    {
        // Values recorded by set_and_validate_objects / set_and_validate_functions, keyed by name.
        // Values may be null; the public views below filter nulls out.
        protected IDictionary<string, Trackable?> _object_dict;
        protected IDictionary<string, Trackable?> _function_dict;
        // Always attached to ObjectsToSerialize under Constants.KERAS_ATTR; recorded values are also
        // copied onto any public property of it whose name equals the key.
        protected AutoTrackable _keras_trackable;
        // Names this tracker accepts; keys outside these sets are ignored.
        internal HashSet<string> _all_functions;
        internal HashSet<string> _all_checkpointable_objects;

        /// <summary>
        /// Creates a tracker that accepts the given object and function names.
        /// </summary>
        /// <param name="checkpointable_objects">Names of the checkpointable objects to track.</param>
        /// <param name="functions">Names of the functions to track.</param>
        protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        {
            _object_dict = new Dictionary<string, Trackable?>();
            _function_dict = new Dictionary<string, Trackable?>();
            _keras_trackable = new AutoTrackable();
            _all_checkpointable_objects = new HashSet<string>(checkpointable_objects);
            _all_functions = new HashSet<string>(functions);
        }

        /// <summary>
        /// Creates a tracker from a (checkpointable object names, function names) pair.
        /// </summary>
        /// <param name="objects_and_functions">Item1: object names; Item2: function names.</param>
        protected SerializedAttributes((IEnumerable<string>, IEnumerable<string>) objects_and_functions)
            : this(objects_and_functions.Item1, objects_and_functions.Item2)
        {
        }

        /// <summary>
        /// The recorded functions whose value is not null (a new dictionary on every access).
        /// </summary>
        public IDictionary<string, Trackable> Functions
            => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

        /// <summary>
        /// The recorded checkpointable objects whose value is not null (a new dictionary on every access).
        /// </summary>
        public IDictionary<string, Trackable> CheckpointableObjects
            => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

        /// <summary>
        /// Returns functions to attach to the root object during serialization: the non-null
        /// recorded functions whose names are registered with this tracker.
        /// </summary>
        public IDictionary<string, Trackable> FunctionsToSerialize
        {
            get
            {
                // TODO: deal with `LayerCall`.
                return Functions
                    .Where(x => _all_functions.Contains(x.Key))
                    .ToDictionary(x => x.Key, x => x.Value);
            }
        }

        /// <summary>
        /// Returns objects to attach to the root object during serialization: the non-null
        /// registered checkpointable objects plus the keras trackable under Constants.KERAS_ATTR.
        /// </summary>
        public IDictionary<string, Trackable> ObjectsToSerialize
        {
            get
            {
                var objects = CheckpointableObjects
                    .Where(x => _all_checkpointable_objects.Contains(x.Key))
                    .ToDictionary(x => x.Key, x => x.Value);
                objects[Constants.KERAS_ATTR] = _keras_trackable;
                return objects;
            }
        }

        /// <summary>
        /// Saves function dictionary, and validates dictionary values.
        /// </summary>
        /// <param name="function_dict">Candidate functions keyed by name; only registered names are read.</param>
        /// <returns>The non-null recorded functions (same as <see cref="Functions"/>).</returns>
        /// <exception cref="ValueError">A registered name maps to a non-null value that is not a Function.</exception>
        public IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict)
        {
            foreach (var key in _all_functions)
            {
                if (!function_dict.TryGetValue(key, out var fn))
                {
                    // TODO(Rinne), high priority: TensorFlow Python raises
                    // ValueError($"Function {key} missing from serialized function dict.") here;
                    // missing functions are skipped until that is implemented.
                    continue;
                }

                // TODO: deal with type `LayerCall`.
                if (fn is not null && fn is not Function)
                {
                    throw new ValueError($"Function dictionary contained a non-function object: {fn} (for key {key}).");
                }

                _function_dict[key] = fn;
                set_keras_trackable_property(key, fn);
            }
            return Functions;
        }

        /// <summary>
        /// Saves objects to a dictionary, and validates the values.
        /// </summary>
        /// <param name="object_dict">Candidate objects keyed by name; only registered names are read.</param>
        /// <returns>The non-null recorded objects (same as <see cref="CheckpointableObjects"/>).</returns>
        public IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict)
        {
            foreach (var key in _all_checkpointable_objects)
            {
                if (!object_dict.TryGetValue(key, out var obj))
                {
                    // TODO(Rinne), high priority: TensorFlow Python raises
                    // ValueError($"Object {key} missing from serialized object dict.") here;
                    // missing objects are skipped until that is implemented.
                    continue;
                }

                _object_dict[key] = obj;
                set_keras_trackable_property(key, obj);
            }
            return CheckpointableObjects;
        }

        /// <summary>
        /// Assigns <paramref name="value"/> to the first public property of the keras trackable whose
        /// name equals <paramref name="key"/>; does nothing when no such property exists.
        /// </summary>
        // Warning: this reflection-based implementation should be considered again.
        private void set_keras_trackable_property(string key, Trackable? value)
        {
            var property = _keras_trackable.GetType().GetProperties().FirstOrDefault(p => p.Name == key);
            property?.SetValue(_keras_trackable, value);
        }

        /// <summary>
        /// Returns a new SerializedAttributes object (corresponding to `new` of tensorflow python)
        /// matching the runtime type of <paramref name="obj"/>. More specific types (Model, RNN) are
        /// tested before the general Layer case.
        /// </summary>
        /// <param name="obj">The Keras object being serialized.</param>
        /// <exception cref="TypeError"><paramref name="obj"/> is null or not a Model, Metric, RNN or Layer.</exception>
        public static SerializedAttributes Create(Trackable obj)
        {
            return obj switch
            {
                Model => new ModelAttributes(),
                Metric => new MetricAttributes(),
                RNN => new RNNAttributes(),
                Layer => new LayerAttributes(),
                // obj may be null here; avoid a NullReferenceException while formatting the message.
                _ => throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {(obj is null ? "null" : obj.GetType().ToString())}")
            };
        }

        /// <summary>
        /// Returns the given name lists, substituting an empty list for a null argument.
        /// </summary>
        protected virtual (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
        {
            return (checkpointable_objects ?? new List<string>(), functions ?? new List<string>());
        }
    }

    // Note that the current implementation still carries some potential risk: TensorFlow Python
    // describes this class as "Common endpoints shared by all models loadable by Keras", but here
    // it is just an ordinary (non-abstract) class.
/// <summary>
/// Serialization attributes registered for every Keras object that derives from these endpoints:
/// the variable collections and the call/signature functions.
/// </summary>
public class CommonEndPoints: SerializedAttributes
{
    private static readonly string[] CommonObjectNames = { "variables", "trainable_variables", "regularization_losses" };
    private static readonly string[] CommonFunctionNames = { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" };

    /// <summary>Registers the given names in addition to the common endpoints.</summary>
    public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        : base(checkpointable_objects.Concat(CommonObjectNames), functions.Concat(CommonFunctionNames))
    {
    }

    /// <summary>Registers only the common endpoints.</summary>
    public CommonEndPoints()
        : base(CommonObjectNames, CommonFunctionNames)
    {
    }
}

/// <summary>
/// Serialization attributes for Keras layers.
/// </summary>
public class LayerAttributes: CommonEndPoints
{
    // The commented-out original list also contained "metrics", "layer_regularization_losses"
    // and "layer_metrics"; those are not tracked yet.
    private static readonly string[] LayerObjectNames = { "non_trainable_variables", "layers" };
    private static readonly string[] LayerFunctionNames = { "call_and_return_conditional_losses", "activity_regularizer_fn" };

    /// <summary>Registers the given names plus the layer-specific objects and functions.</summary>
    public LayerAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        : base(checkpointable_objects.Concat(LayerObjectNames), functions.Concat(LayerFunctionNames))
    {
    }

    /// <summary>Registers the layer-specific objects only.</summary>
    // NOTE(review): unlike the constructor above, this one registers no layer-specific functions
    // (call_and_return_conditional_losses, activity_regularizer_fn) — confirm this is intended.
    public LayerAttributes()
        : base(LayerObjectNames, Array.Empty<string>())
    {
    }
}

/// <summary>
/// Serialization attributes for Keras models; currently identical to <see cref="LayerAttributes"/>.
/// </summary>
public class ModelAttributes: LayerAttributes
{
    public ModelAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        : base(checkpointable_objects, functions)
    {
    }

    public ModelAttributes()
        : base()
    {
    }
}

/// <summary>
/// Serialization attributes for Keras metrics: only their variables.
/// </summary>
public class MetricAttributes : SerializedAttributes
{
    private static readonly string[] MetricObjectNames = { "variables" };

    public MetricAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        : base(checkpointable_objects.Concat(MetricObjectNames), functions)
    {
    }

    public MetricAttributes()
        : base(MetricObjectNames, Array.Empty<string>())
    {
    }
}
/// <summary>
/// Serialization attributes for RNN layers: everything registered for a layer plus the
/// "states" checkpointable object.
/// </summary>
public class RNNAttributes: LayerAttributes
{
    // "states" is a checkpointable object, not a function (Keras Python declares
    // RNNAttributes with checkpointable_objects=['states']). Registering it as a function
    // excluded it from ObjectsToSerialize and made set_and_validate_functions reject any
    // non-Function value supplied for it.
    private static readonly string[] RnnObjectNames = { "states" };

    /// <summary>Registers the given names plus the layer names and "states".</summary>
    public RNNAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
        : base(checkpointable_objects.Concat(RnnObjectNames), functions)
    {
    }

    /// <summary>Registers the layer names and "states".</summary>
    public RNNAttributes()
        : base(RnnObjectNames, Array.Empty<string>())
    {
    }
}
}