From 726b742157eb2bdc4f5063e9a0f9093bfd03aa33 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 14 Jan 2023 15:45:02 +0800 Subject: [PATCH 01/10] Add check for dims of x and y in model.fit. --- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index e0b4af78..40dd4ab6 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -31,6 +31,11 @@ namespace Tensorflow.Keras.Engine int workers = 1, bool use_multiprocessing = false) { + if (x.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); + } int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); var train_x = x[new Slice(0, train_count)]; var train_y = y[new Slice(0, train_count)]; From bb8168b5ca9bc78a824d429eb7bd5f4ac9e4fa8d Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 21 Jan 2023 11:07:07 +0800 Subject: [PATCH 02/10] Init the serialization of keras pb model. 
--- src/TensorFlowNET.Core/APIs/tf.compat.cs | 22 ++ .../Checkpoint/CheckPointUtils.cs | 150 +++++++++ .../Checkpoint/CheckpointOptions.cs | 5 + .../Checkpoint/ObjectGraphView.cs | 63 ++++ .../Checkpoint/SaveUtilV1.cs | 229 ++++++++++++++ .../Checkpoint/TrackableSaver.cs | 109 +++++++ .../Checkpoint/TrackableView.cs | 75 +++++ .../Exceptions/AssertionError.cs | 14 + .../Framework/meta_graph.cs | 63 +++- .../Functions/ConcreteFunction.cs | 3 +- src/TensorFlowNET.Core/Functions/Function.cs | 11 +- .../ModelSaving/SaveOptions.cs | 8 +- .../Operations/resource_variable_ops.cs | 6 + .../Protobuf/SavedObjectGraph.cs | 10 +- .../Protobuf/TrackableObjectGraph.cs | 6 + .../Training/AutoTrackable.cs | 15 + src/TensorFlowNET.Core/Training/Optimizer.cs | 7 +- .../Training/Saving/SaveableObject.cs | 14 + .../Training/Saving/SavedModel/AssetInfo.cs | 11 + .../Saving/SavedModel/AugmentedGraphView.cs | 60 ++++ .../Training/Saving/SavedModel/Constants.cs | 33 ++ .../Saving/SavedModel/RevivedTypes.cs | 17 + .../Training/Saving/SavedModel/SaveType.cs | 9 + .../Saving/SavedModel/SaveableView.cs | 299 ++++++++++++++++++ .../Saving/SavedModel/TagConstants.cs | 10 + .../Training/Saving/SavedModel/builder.cs | 22 ++ .../Training/Saving/SavedModel/save.cs | 256 +++++++++++++++ .../SavedModel/signature_serialization.cs | 58 ++++ .../Training/Saving/SavedModel/utils.cs | 52 +++ .../Saving/saveable_object_util.py.cs | 19 +- src/TensorFlowNET.Core/Training/Trackable.cs | 79 ++++- .../Training/TrackableUtils.cs | 148 +++++++++ .../Variables/BaseResourceVariable.cs | 1 + src/TensorFlowNET.Core/ops.cs | 18 ++ .../Engine/Layer.Serialize.cs | 31 ++ src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 15 +- src/TensorFlowNET.Keras/Engine/Model.cs | 6 + .../Protobuf/SavedMetadata.cs | 12 + src/TensorFlowNET.Keras/Protobuf/Versions.cs | 7 + .../Saving/SavedModel/Constants.cs | 41 +++ .../Saving/SavedModel/KerasObjectWrapper.cs | 11 + 
.../Saving/SavedModel/Save.cs | 115 +++++++ .../Saving/SavedModel/SaveImpl.cs | 19 ++ .../Saving/SavedModel/base_serialization.cs | 40 +++ .../Saving/SavedModel/layer_serialization.cs | 62 ++++ .../Saving/SavedModel/utils.cs | 33 ++ test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 60 ++++ .../Tensorflow.Binding.UnitTest.csproj | 2 +- 49 files changed, 2347 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableView.cs create mode 100644 src/TensorFlowNET.Core/Exceptions/AssertionError.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs create mode 100644 src/TensorFlowNET.Core/Training/TrackableUtils.cs create mode 100644 src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs create mode 100644 
src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveTest.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.cs b/src/TensorFlowNET.Core/APIs/tf.compat.cs index 4d979eb5..5b2b5a10 100644 --- a/src/TensorFlowNET.Core/APIs/tf.compat.cs +++ b/src/TensorFlowNET.Core/APIs/tf.compat.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Text; + namespace Tensorflow { public partial class tensorflow @@ -23,6 +25,26 @@ namespace Tensorflow public class CompatApi { public CompatV1Api v1 { get; } = new CompatV1Api(); + + internal string as_text(string bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return bytes_or_text; + } + internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return encoding.GetString(bytes_or_text); + } + + internal string as_str(string bytes_or_text, Encoding? encoding = null) + { + return as_text(bytes_or_text, encoding); + } + internal string as_str(byte[] bytes_or_text, Encoding? 
encoding = null) + { + return as_text(bytes_or_text, encoding); + } } public bool executing_eagerly() diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs new file mode 100644 index 00000000..70d77155 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -0,0 +1,150 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint; + +public static class CheckPointUtils +{ + private static string _ESCAPE_CHAR = "."; + public static (List, Dictionary>, Dictionary, + IDictionary>, + Dictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); + return (trackable_objects, node_paths, node_ids, slot_variables, object_names); + } + + public static + IDictionary> + serialize_slot_variables(IEnumerable trackable_objects, + IDictionary node_ids, IDictionary object_names) + { + var non_slot_objects = trackable_objects.ToList(); + Dictionary> + slot_variables = new(); + foreach (var trackable in non_slot_objects) + { + if (trackable is not Optimizer) + { + continue; + } + + var optim = (Optimizer)trackable; + var slot_names = optim.get_slot_names(); + foreach (var slot_name in slot_names) + { + for (int original_variable_node_id = 0; + original_variable_node_id < non_slot_objects.Count; + original_variable_node_id++) + { + var original_variable = 
non_slot_objects[original_variable_node_id]; + IVariableV1 slot_variable; + if (original_variable is not IVariableV1) + { + slot_variable = null; + } + slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); + if(slot_variable is null) continue; + + // There're some problems about the inherits of `Variable` and `Trackable`. + throw new NotImplementedException(); + } + } + } + + return slot_variables; + } + + public static Trackable get_mapped_trackable(Trackable trackable, IDictionary? object_map) + { + if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) + { + return trackable; + } + else + { + return possible_res; + } + } + + public static string get_full_name(Trackable var) + { + // TODO: This state is not correct, the whole framework need to be updated in the future. + if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) + { + return ""; + } + // skip the check of attribute `_save_slice_info` . + + // TODO: Need to be revised!!! + return ((ResourceVariable)(object)var).Name; + } + + public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) + { + HashSet checkpointed_trackables = new(); + Dictionary> parents = new(); + for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + { + var object_proto = object_graph_proto.Nodes[i]; + // skip the process of registered saver. 
+ if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || + object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) + { + checkpointed_trackables.Add(i); + } + + foreach (var child_proto in object_proto.Children) + { + var child = child_proto.NodeId; + if (!parents.ContainsKey(child)) + { + parents[child] = new HashSet(); + } + + parents[child].Add(i); + } + } + + Queue to_visit = new(checkpointed_trackables.AsEnumerable()); + while (to_visit.Count > 0) + { + var trackable = to_visit.Dequeue(); + if (!parents.ContainsKey(trackable)) continue; + var current_parents = parents[trackable]; + foreach (var parent in current_parents) + { + checkpointed_trackables.Add(parent); + if (parents.ContainsKey(parent)) + { + to_visit.Enqueue(parent); + } + } + parents.Remove(trackable); + } + + // TODO: Complete it after supporting checkpoint. + // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + // { + // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); + // } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs new file mode 100644 index 00000000..d8297ea3 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Checkpoint; + +public record class CheckpointOptions( + string experimental_io_device = null, + bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs new file mode 100644 index 00000000..2ad55448 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Serilog.Debugging; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint; + +public class 
ObjectGraphView: TrackableView, ICloneable +{ + protected IEnumerable? _attached_dependencies; + // TODO: attached_dependencies + public ObjectGraphView(Trackable root, IEnumerable? attached_dependencies = null): base(root) + { + _attached_dependencies = attached_dependencies; + } + + public object Clone() + { + // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ + return new ObjectGraphView(Root, _attached_dependencies); + } + + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + List res = base.children(obj, save_type) + .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + // Check the reference, not value. + if (obj == Root && _attached_dependencies is not null) + { + res.AddRange(_attached_dependencies); + } + + return res; + } + + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + } + + public IEnumerable? AttachedDependencies + { + get => _attached_dependencies; + } + + public virtual (List, Dictionary>) breadth_first_traversal() + { + return base._descendants_with_paths(); + } + + // TODO: complete the implementation + public void serialize_object_graph(object? saveables_cache = null) + { + throw new NotImplementedException(); + } + + // TODO: complete the implementation + public void frozen_saveable_objects(object? object_map = null, object? 
to_graph = null, object call_with_mapped_captures = null) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs new file mode 100644 index 00000000..7724c6b7 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -0,0 +1,229 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Exceptions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +public static class SaveUtilV1 +{ + public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + IDictionary? object_map = null) + { + // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, + // till now only internal registrations are allowed. So, we won't return a saver in this function. + // The implementation of this function should be updated if tensorflow update it. + Dictionary> checkpoint_factory_map = new(); + foreach (var pair in object_names) + { + var trackable = pair.Key; + var object_name = pair.Value; + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + + // skip the registration process. + + List current_list = new(); + foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) + { + // treat name as key_suffix. + var name = name_and_factory.Key; + var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); + + current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); + } + + checkpoint_factory_map[trackable] = current_list; + } + + return (checkpoint_factory_map, null); + } + + public static (List, object?) 
frozen_saveables_and_savers(ObjectGraphView graph_view, + IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, + object? saveables_cache = null) + { + + Graph target_context; + if (to_graph is not null) + { + using (to_graph.as_default()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + else + { + using (new ops.NullContextManager()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + } + + public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + IDictionary object_map, bool call_with_mapped_captures, object? 
saveables_cache = null) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); + var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( + trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, + saveables_cache); + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); + } + + private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList trackable_objects, + IDictionary node_ids, + IDictionary> + slot_variables) + { + TrackableObjectGraph object_graph_proto = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + var trackable = trackable_objects[i]; + Debug.Assert(node_ids[trackable] == i); + TrackableObjectGraph.Types.TrackableObject object_proto; + if (slot_variables.TryGetValue(trackable, out var slots)) + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); + } + else + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(); + } + object_graph_proto.Nodes.Add(object_proto); + foreach (var child in graph_view.list_children(trackable)) + { + object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { NodeId = node_ids[child.Refer], LocalName = child.Name }); + } + } + + return object_graph_proto; + } + + private static (List, object?, 
object?) add_attributes_to_object_graph(IList trackable_objects, + TrackableObjectGraph object_graph_proto, IDictionary node_ids, + IDictionary object_names, IDictionary object_map, + bool call_with_mapped_captures, object? saveables_cache = null) + { + int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + Debug.Assert(node_ids[trackable_objects[i]] == i); + } + + var (checkpoint_factory_map, unmmaped_registered_savers) = + get_checkpoint_factories_and_keys(object_names, object_map); + + // skip the process of registered savers + + var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, + object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); + return (named_saveable_objects, feed_additions, null); + } + + public static (List, object?) generate_saveable_objects( + IDictionary> checkpoint_factory_map, + TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + List named_saveable_objects = new(); + foreach (var pair in checkpoint_factory_map) + { + var trackable = pair.Key; + var factory_data_list = pair.Value; + bool fill_object_proto = object_graph_proto is not null && node_ids is not null; + TrackableObjectGraph.Types.TrackableObject object_proto = null!; + if (fill_object_proto) + { + object_proto = object_graph_proto.Nodes[node_ids[trackable]]; + } + + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + // skip cache + + foreach (var factory_data in factory_data_list) + { + var name = factory_data.name; + var key = factory_data.checkpoint_key; + var saveable_factory = factory_data.factory; + + // TODO: oneflow python has a process with callable `saveable_factory`. 
+ var maybe_saveable = saveable_factory; + IEnumerable savesbles; + if (maybe_saveable is MySaveableObject) + { + savesbles = new List() { (MySaveableObject)maybe_saveable }; + } + else if (maybe_saveable is Tensor) + { + savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + } + else + { + throw new TypeError("Unexpected type."); + } + + foreach (var saveable in savesbles) + { + if (!saveable.name.Contains(key)) + { + throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + + $"'{saveable.name}' for attribute '{name}'. Expected a name" + + $" containing '{key}'."); + } + } + + // skip the process of PythonState + + named_saveable_objects.AddRange(savesbles); + + if(!fill_object_proto) continue; + + // skip the process of TrackableSaveable + + object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); + } + } + + return (named_saveable_objects, null); + } +} + +public record class CheckpointFactoryData +( + object factory, + string name, + string checkpoint_key +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs new file mode 100644 index 00000000..7d101d5e --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; + +namespace Tensorflow.Checkpoint; + +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private EagerTensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? 
_file_prefix_feed_tensor = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private void gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + throw new NotImplementedException(); + } + + private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + throw new NotImplementedException(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); + _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. 
+ file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs new file mode 100644 index 00000000..ed1f3ec4 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -0,0 +1,75 @@ +using System; +using Tensorflow.Train; +using System.Collections.Generic; +using System.IO; + +namespace Tensorflow.Checkpoint; + +public class TrackableView +{ + protected WeakReference _root_ref; + public TrackableView(Trackable obj) + { + _root_ref = new WeakReference(obj); + } + + public TrackableView(WeakReference obj) + { + _root_ref = obj; + } + + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + obj._maybe_initialize_trackable(); + // Note: in python the return type of `Trackable._trackable_children` is not fixed. + // Therefore it uses `convert_to_trackable` to have an extra process. 
+ return obj._trackable_children(save_type); + } + + public Trackable Root + { + get + { + if (_root_ref.TryGetTarget(out Trackable res)) + { + return res; + } + else + { + throw new InvalidDataException( + "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); + } + } + } + + /// + /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. + /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths + /// + protected (List, Dictionary>) _descendants_with_paths() + { + List bfs_sorted = new(); + Queue to_visit = new(); + Dictionary> node_paths = new(); + node_paths[this.Root] = new List(); + while (!to_visit.empty()) + { + var current_trackable = to_visit.Dequeue(); + bfs_sorted.Add(current_trackable); + var children_dict = this.children(current_trackable); + foreach (var name in children_dict.Keys) + { + var dependency = children_dict[name]; + if (!node_paths.ContainsKey(dependency)) + { + var list = new List(node_paths[current_trackable]); + list.Add(new TrackableReference(name, dependency)); + node_paths[dependency] = list; + to_visit.Enqueue(dependency); + } + } + } + + return (bfs_sorted, node_paths); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs new file mode 100644 index 00000000..84ec24cb --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Exceptions; + +public class AssertionError : TensorflowException +{ + public AssertionError() : base() + { + + } + + public AssertionError(string message) : base(message) + { + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 6ce3bf3c..cce13b55 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ 
b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -304,7 +304,7 @@ namespace Tensorflow } } - private static OpList stripped_op_list_for_graph(GraphDef graph_def) + public static OpList stripped_op_list_for_graph(GraphDef graph_def) { var used_ops = ops_used_by_graph_def(graph_def); @@ -345,5 +345,66 @@ namespace Tensorflow return used_ops.ToArray(); } + + private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) + { + foreach (var attr_def in op_def.Attr) + { + if (attr_def.Name == attr_name) + { + if (attr_def.DefaultValue is null) return false; + // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. + return true; + } + } + + return false; + } + + public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) + { + Dictionary op_name_to_function = new(); + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + op_name_to_function[function_def.Signature.Name] = function_def; + } + + Action _strip_node_default_valued_attrs = (node_def) => + { + if (op_name_to_function.ContainsKey(node_def.Op)) return; + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is null) return; + + HashSet attrs_to_strip = new(); + foreach (var attr in node_def.Attr) + { + if (is_default_attr_value(op_def, attr.Key, attr.Value)) + { + attrs_to_strip.Add(attr.Key); + } + } + + foreach (var attr in attrs_to_strip) + { + node_def.Attr.Remove(attr); + } + }; + + foreach (var node_def in meta_graph_def.GraphDef.Node) + { + _strip_node_default_valued_attrs(node_def); + } + + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + foreach (var function_node_def in function_def.NodeDef) + { + _strip_node_default_valued_attrs(function_node_def); + } + } + + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index c52d0b5f..bac9cedb 
100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -10,7 +11,7 @@ namespace Tensorflow.Functions /// /// /// - public class ConcreteFunction + public class ConcreteFunction: Trackable { FuncGraph func_graph; ForwardBackwardCall forward_backward; diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index d57097ae..056d15f4 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -1,16 +1,23 @@ using System; +using Tensorflow.Train; namespace Tensorflow { - public class Function + public class Function: Trackable { #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - + + public string Name { get; set; } public Function() { } + + public Function(string name) + { + Name = name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index e25537d8..fce42850 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving /// public class SaveOptions { - bool save_debug_info; + public bool save_debug_info = false; + public IList? namespace_white_list { get; set; } = null; + public IDictionary? function_aliases { get; set; } = null; + public string? experimental_io_device { get; set; } = null; + // TODO: experimental + public Object? 
experimental_variable_polict { get; set; } = null; + public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index ee751acf..d5a32c10 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.Train; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -38,6 +39,11 @@ namespace Tensorflow { return var is ResourceVariable; } + + public static bool is_resource_variable(Trackable var) + { + return var is BaseResourceVariable; + } /// /// Creates a variable handle with information to do shape inference. diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 9d3e854a..f2597574 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -156,7 +156,7 @@ namespace Tensorflow { /// Nodes[0] is considered the root node. 
/// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Nodes { + public pbc::RepeatedField Nodes { get { return nodes_; } } @@ -286,6 +286,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public SavedObject(SavedObject other) : this() { children_ = other.children_.Clone(); + dependencies_ = other.dependencies_.Clone(); slotVariables_ = other.slotVariables_.Clone(); saveableObjects_ = other.saveableObjects_.Clone(); switch (other.KindCase) { @@ -328,6 +329,7 @@ namespace Tensorflow { private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// /// Objects which this object depends on: named edges in the dependency /// graph. @@ -338,6 +340,11 @@ namespace Tensorflow { public pbc::RepeatedField Children { get { return children_; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dependencies { + get { return dependencies_; } + } /// Field number for the "slot_variables" field. 
public const int SlotVariablesFieldNumber = 3; @@ -617,6 +624,7 @@ namespace Tensorflow { return; } children_.Add(other.children_); + dependencies_.Add(other.dependencies_); slotVariables_.Add(other.slotVariables_); saveableObjects_.Add(other.saveableObjects_); switch (other.KindCase) { diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 3aa747c2..93413667 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -198,6 +198,12 @@ namespace Tensorflow { public TrackableObject() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot) { + OnConstruction(); + slotVariables_ = slot; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d2198e37..d8f6314b 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -2,5 +2,20 @@ { public abstract class AutoTrackable : Trackable { + public void _delete_tracking(string name) + { + _maybe_initialize_trackable(); + if (_unconditional_dependency_names.ContainsKey(name)) + { + _unconditional_dependency_names.Remove(name); + for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) + { + if (_unconditional_checkpoint_dependencies[i].Name == name) + { + _unconditional_checkpoint_dependencies.RemoveAt(i); + } + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index f985c656..e656fe96 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -351,7 +351,7 @@ namespace Tensorflow /// /// /// - protected IVariableV1 get_slot(IVariableV1 var, string name) + internal IVariableV1 get_slot(IVariableV1 var, 
string name) { var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; if (named_slots == null) @@ -360,6 +360,11 @@ namespace Tensorflow return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; } + internal IEnumerable get_slot_names() + { + return _slots.Keys; + } + private string _var_key(IVariableV1 var) { return $"{var.Op.graph.graph_key}.{var.Op.name}"; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index c86075f8..6239030b 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -48,4 +48,18 @@ namespace Tensorflow validate_shape: restored_shapes == null && op.shape.IsFullyDefined); } } + + public class NoRestoreSaveable: MySaveableObject + { + public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, + new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) + { + + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + return control_flow_ops.no_op(); + } + } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs new file mode 100644 index 00000000..24c8f2f0 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Tensorflow; + +public record class AssetInfo +( + List asset_defs, + Dictionary asset_initializers_by_resource, + Dictionary asset_filename_map, + Dictionary asset_index +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs new file mode 100644 index 00000000..6723206c --- /dev/null +++ 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -0,0 +1,60 @@ +using System; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; + +namespace Tensorflow; + +public class AugmentedGraphView: ObjectGraphView +{ + // private object _children_cache; + // private object _serialization_cache; + private List _untraces_functions; + public AugmentedGraphView(Trackable root): base(root) + { + _untraces_functions = new(); + } + + public void set_signature(object signature_map, object wrapped_functions) + { + // TODO: cache + list_children(Root); + } + + public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + Dictionary children = new(); + foreach (var pair in base.list_children(obj, save_type)) + { + var name = pair.Name; + var child = pair.Refer; + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + + return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + } + + public override (List, Dictionary>) breadth_first_traversal() + { + // TODO: implement it if needed. + return base.breadth_first_traversal(); + } + + public List<(string, Trackable)> list_dependencies(Trackable obj) + { + // TODO: deal with cache. 
+ return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + } + + public Trackable get_child(Trackable obj, string name) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs new file mode 100644 index 00000000..cb7abada --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -0,0 +1,33 @@ +namespace Tensorflow; + +public static class Constants +{ + public static readonly string ASSETS_DIRECTORY = "assets"; + public static readonly string ASSETS_KEY = "saved_model_assets"; + + public static readonly string DEBUG_DIRECTORY = "debug"; + + public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; + + public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; + + public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; + + public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + + public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; + + public static readonly string MAIN_OP_KEY = "saved_model_main_op"; + + public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; + public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; + + public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; + + public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; + + public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; + + public static readonly string VARIABLES_DIRECTORY = "variables"; + public static readonly string VARIABLES_FILENAME = "variables"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs new file mode 100644 index 00000000..fa9d6e50 --- /dev/null +++ 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -0,0 +1,17 @@ +using Tensorflow.Train; + +namespace Tensorflow; + +public class RevivedTypes +{ + /// + /// Create a SavedUserObject from a trackable object. + /// + /// + /// + public static SavedUserObject? serialize(Trackable obj) + { + // TODO: complete the implementation. + return null; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs new file mode 100644 index 00000000..b973fd41 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tensorflow; + +public enum SaveType +{ + SAVEDMODEL, + CHECKPOINT +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs new file mode 100644 index 00000000..6a241f0e --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class SaveableView +{ + private AugmentedGraphView _augmented_graph_view; + private SaveOptions _options; + private List _trackable_objects; + private List _nodes; + private Dictionary> _node_paths; + private Dictionary _node_ids; + private IDictionary> + _slot_variables; + private Dictionary _object_names; + private List _gradient_functions; // to be completed + private List _gradient_defs; // to be completed + private List _concrete_functions; + private Dictionary _captured_tensor_node_ids; + private 
Dictionary> _saveable_objects_map; + private Dictionary _obj_to_registered_saver; + + public AugmentedGraphView AugmentedGraphView + { + get => _augmented_graph_view; + } + + public Trackable Root + { + get => _nodes[0]; + } + public List Nodes + { + get => _nodes; + } + public Dictionary NodeIds + { + get => _node_ids; + } + public List GradientDefs + { + get => _gradient_defs; + } + public Dictionary> NodePaths + { + get => _node_paths; + } + public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) + { + _augmented_graph_view = augmented_graph_view; + _options = options; + + (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = + CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); + + // TODO: deal with untraced functions. + + initialize_save_and_restore_functions(); + initialize_nodes_and_concrete_functions(); + + _captured_tensor_node_ids = new(); + } + + private void initialize_save_and_restore_functions() + { + // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. + SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. + _obj_to_registered_saver = new(); + _saveable_objects_map = new(); + } + + // Copies the tracked objects into `_nodes`, resets the gradient containers and collects every ConcreteFunction node. + private void initialize_nodes_and_concrete_functions() + { + _nodes = _trackable_objects.ConvertAll(x => x); // shallow copy: a new list that shares the same Trackable instances + _gradient_functions = new(); + _gradient_defs = new(); + // Must exist before the loop below appends ConcreteFunction nodes to it. + _concrete_functions = new(); + + // TODO: deal with the condition that obj in `_saveable_objects_map`. + // foreach (var obj in _nodes) + // { + // + // } + + foreach (var obj in _nodes) + { + if (obj is ConcreteFunction) + { + _concrete_functions.Add((ConcreteFunction)obj); + } + } + } + + public List get_concrete_resource_initializers() + { + // TODO: complete the implementation.
+ return new List(); + } + + public (Dictionary, Dictionary, AssetInfo) map_resources() + { + Debug.Assert(!tf.Context.executing_eagerly()); + + Dictionary object_map = new(); + Dictionary tensor_map = new(); + + AssetInfo assetInfo = new(new List(), new Dictionary(), + new Dictionary(), new Dictionary()); + + foreach (var node_id in dependency_sorted_node_ids()) + { + var obj = _nodes[node_id]; + var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); + // TODO: deal with Asset (if obj is Asset) + foreach (var tensor in tensors) + { + _captured_tensor_node_ids[tensor] = node_id; + } + } + + return (object_map, tensor_map, assetInfo); + } + + /// + /// Returns topologically sorted nodes, sorted by dependencies. + /// + public List dependency_sorted_node_ids() + { + Dictionary> dependency_map = new(); + foreach (var node in _nodes) + { + var node_id = _node_ids[node]; + List deps = new(); + + // TODO: deal with captured tensor. + + string node_path; + foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) + { + if (!_node_ids.ContainsKey(dep)) + { + node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + throw new ValueError( + $"Found an untracked dependency. Object {node_path} depends on {dep}, " + + $"but this dependency isn't listed as a child. 
Please track this child by " + $"overriding `_trackable_children` or use `._track_trackable`."); + } + deps.Add(_node_ids[dep]); + } + + // Register this node's dependency ids; otherwise the sort below always receives an empty map. + dependency_map[node_id] = deps; + } + + try + { + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError err) + { + List pretty_printed_nodes = new(); + List pretty_printed_dependencies = new(); + + foreach (var pair in err.LeftOverDependencyMap) + { + var x = pair.Key; + var deps = pair.Value; + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); + pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); + pretty_printed_dependencies.Add( + $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); + } + + throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + + $"Saving cannot continue until this cycle is resolved." + + $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + + $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); + } + } + + /// + /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph + /// + /// + /// + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + { + SavedObjectGraph proto = new(); + fill_object_graph_proto(proto); + + // TODO: complete the process of concrete functions.
+ + int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + var obj = _nodes[i]; + var obj_proto = proto.Nodes[i]; + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), + options); + } + + return proto; + } + + private static void write_object_proto(Trackable obj, SavedObject proto, + IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + { + // skip the process of type Asset + if (resource_variable_ops.is_resource_variable(obj)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is Function) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is ConcreteFunction) + { + // TODO: complete it. + throw new NotImplementedException(); + } + // skip the process of type `_CapturedTensor` and `CapturableResource`. + else + { + var registered_type_proto = RevivedTypes.serialize(obj); + if (registered_type_proto is null) + { + registered_type_proto = new SavedUserObject() + { + Identifier = obj.ObjectIdentifier, + Version = new VersionDef() + { + Producer = 1, + MinConsumer = 1, + BadConsumers = { } + } + }; + } + + proto.UserObject = new SavedUserObject(registered_type_proto); + } + + // TODO: try get the registered_name from `registration`. 
+ } + + public void fill_object_graph_proto(SavedObjectGraph proto) + { + for (int node_id = 0; node_id < _nodes.Count; node_id++) + { + var node = _nodes[node_id]; + Debug.Assert(_node_ids[node] == node_id); + SavedObject object_proto = new(); + if (_slot_variables.TryGetValue(node, out var value)) + { + object_proto.SlotVariables.AddRange(value); + } + // skip the check of type `_CapturedTensor` + foreach (var child in _augmented_graph_view.list_children(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[child.Refer]; + child_proto.LocalName = child.Name; + object_proto.Children.Add(child_proto); + } + + foreach (var pair in _augmented_graph_view.list_dependencies(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[pair.Item2]; + child_proto.LocalName = pair.Item1; + object_proto.Dependencies.Add(child_proto); + } + + if (_saveable_objects_map.ContainsKey(node)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if(_obj_to_registered_saver.ContainsKey(node)) + { + // TODO: complete it. + // We now skip it for the lack of `SavedObject.registered_saver` API. 
+ throw new NotImplementedException(); + } + + proto.Nodes.Add(object_proto); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs new file mode 100644 index 00000000..9a066eed --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -0,0 +1,10 @@ +namespace Tensorflow; + +public static class TagConstants +{ + public static readonly string SERVING = "serve"; + public static readonly string TRAINING = "train"; + public static readonly string EVAL = "eval"; + public static readonly string GPU = "gpu"; + public static readonly string TPU = "tpu"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs new file mode 100644 index 00000000..bcd3ae05 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BuilderUtils +{ + public static void copy_assets_to_destination_dir(IDictionary asset_filename_map, + string destination_dir, HashSet? saved_files = null) + { + if (saved_files is null) saved_files = new HashSet(); + + var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); + + // TODO: complete the implementation of this function. 
+ if (asset_filename_map is not null && asset_filename_map.Count > 0) + { + throw new NotImplementedException(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs new file mode 100644 index 00000000..69235605 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -0,0 +1,256 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.Protobuf; +using Tensorflow.Checkpoint; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + private static readonly IEnumerable byte_swappable = new List() + { + dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, + dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, + dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 + }.Select(x => (int)x); + + public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, + string export_dir, IDictionary? signatures, SaveOptions? 
options = null, bool experimental_skip_checkpoint = false) + { + if (options is null) + { + options = new SaveOptions(); + } + + var saved_model = new Tensorflow.SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + + var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = + _build_meta_graph(obj, signatures, options, meta_graph_def); + saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; + + if (!experimental_skip_checkpoint) + { + Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); + CheckpointOptions ckpt_options = new(options.experimental_io_device); + object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + } + BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); + + if (tf.Context.executing_eagerly()) + { + // tensorflow python has a check of `context.async_wait()` here. + } + + // TODO: deal with `pywrap_saved_model.Save(export_dir)`. + + // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. + if (true) + { + var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), + tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); + // TODO: add c api and complete the fingerprint def. + var fingerprint_proto = ""; + File.WriteAllText(fingerprint_path, fingerprint_proto); + } + + var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); + // `saved_model.pb` holds the binary protobuf encoding; `ToString()` would emit the JSON text form instead. + File.WriteAllBytes(path, saved_model.ToByteArray()); + + if (options.save_debug_info) + { + throw new NotImplementedException(); + } + + ops.dismantle_graph(exported_graph); + + return (saved_nodes, node_paths); + } + + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, + Dictionary>) _build_meta_graph(Trackable obj, + IDictionary? signatures, SaveOptions options, MetaGraphDef?
meta_graph_def = null) + { + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } + + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } + + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is not null) + { + throw new NotImplementedException(); + } + + // TODO: process of signatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) + { + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } + } + + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } + + private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, + IDictionary signatures, IEnumerable namespace_whitelist, + bool save_custom_gradients) + { + var resource_initializers = saveable_view.get_concrete_resource_initializers(); + var exported_graph = new Graph(); + + Dictionary object_map; + Dictionary tensor_map; + AssetInfo asset_info; + using (var g = exported_graph.as_default()) + { + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures.
+ if (save_custom_gradients) + { + // TODO: trace gradient functions. + } + + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers + } + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + // Record the init op in the main-op collection, creating the collection and its node list on first use. + // `RepeatedField.Add` mutates the list; LINQ's `Append` only returns a new sequence and leaves it unchanged. + if (!meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + var main_op_collection = meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY]; + if (main_op_collection.NodeList is null) + { + main_op_collection.NodeList = new(); + } + main_op_collection.NodeList.Value.Add(init_op.name); + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + } + + foreach (var obj in object_map.Values) + { + obj._maybe_initialize_trackable(); + } + + var (named_saveable_objects, registered_savers) = + SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); + + // TODO: complete the save of checkpoints with `MultiDeviceSaver`. + + saveable_view.dependency_sorted_node_ids(); + + var graph_def = exported_graph.as_graph_def(true); + graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); + verify_ops(graph_def, namespace_whitelist); + + meta_graph_def.GraphDef = new GraphDef(graph_def); + // Message-typed protobuf fields are null until assigned, so create the containers before filling them. + if (meta_graph_def.MetaInfoDef is null) + { + meta_graph_def.MetaInfoDef = new(); + } + meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); + meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; + // TODO: add git version. + meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + if (meta_graph_def.MetaInfoDef.StrippedOpList is null) + { + meta_graph_def.MetaInfoDef.StrippedOpList = new(); + } + meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); + meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); + + // TODO: deal with signatures here.
+ + meta_graph.strip_graph_default_valued_attrs(meta_graph_def); + + if (!BitConverter.IsLittleEndian) + { + swap_function_tensor_content(meta_graph_def); + } + + return (asset_info, exported_graph); + } + + private static void verify_ops(GraphDef graph_def, IEnumerable? namespace_whitelist) + { + return; + // if (namespace_whitelist is null || !namespace_whitelist.Any()) + // { + // return; + // } + + // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. + } + + public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) + { + var functions = meta_graph_def.GraphDef.Library.Function; + foreach (var function in functions) + { + var node_def = function.NodeDef; + foreach (var node in node_def) + { + if (node.Op == "Const") + { + var tensor = node.Attr["value"].Tensor; + byte_swap_tensor_content(tensor); + } + } + } + } + + public static void byte_swap_tensor_content(TensorProto tensor) + { + if (byte_swappable.Contains((int)tensor.Dtype)) + { + var tshape = tensor.TensorShape.Dim; + var tensor_bytes = tensor.TensorContent; + if (tensor_bytes is not null && !tensor_bytes.IsEmpty) + { + long tensor_size = 1; + foreach (var sz in tshape) + { + tensor_size *= sz.Size; + } + + var chunksize = tensor_bytes.Length / tensor_size; + List reversed_bytes = new(); + for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) + { + var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); + reversed_bytes.AddRange(current); + } + tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs new file mode 100644 index 00000000..21272941 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Linq; 
+using Tensorflow.Functions; +using Tensorflow.Train; + +namespace Tensorflow; + +// Trackable that exposes registered signature functions as its children when saving as a SavedModel. +public class SignatureMap: Trackable +{ + private Dictionary _signatures; + private Dictionary _concrete_signatures; + + public SignatureMap() + { + _signatures = new(); + // Both maps are written by `_add_signature` and read by `_trackable_children`, so both must be created here. + _concrete_signatures = new(); + } + + public void _add_signature(string name, ConcreteFunction concrete_function) + { + _concrete_signatures[name] = concrete_function; + } + + public void _add_signature(string name, Function concrete_function) + { + _signatures[name] = concrete_function; + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if (save_type != SaveType.SAVEDMODEL) + { + return new Dictionary(); + } + + Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); + foreach (var pair in _concrete_signatures) + { + res[pair.Key] = pair.Value; + } + + return res; + } + + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + // TODO: assert the arg_keywords + signature_map._add_signature(name, func); + } + + return signature_map; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs new file mode 100644 index 00000000..723419f6 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -0,0 +1,52 @@ +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + /// + /// Return variables sub-directory, or create one if it doesn't exist.
+ /// + /// + public static string get_or_create_variables_dir(string export_dir) + { + var variables_dir = get_variables_dir(export_dir); + Directory.CreateDirectory(variables_dir); + return variables_dir; + } + + /// + /// Return variables sub-directory in the SavedModel. + /// + /// + /// + public static string get_variables_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); + } + + /// + /// Return assets sub-directory, or create one if it doesn't exist. + /// + /// + /// + public static string get_or_create_assets_dir(string export_dir) + { + var assets_destination_dir = get_assets_dir(export_dir); + Directory.CreateDirectory(assets_destination_dir); + return assets_destination_dir; + } + + /// + /// Return path to asset directory in the SavedModel. + /// + /// + /// + public static string get_assets_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 3a664788..98cdb274 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -17,12 +17,17 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow { - public class saveable_object_util + public static class saveable_object_util { + public class TrackableSaveable: MySaveableObject + { + + } /// /// Returns the variables and names that will be used for a Saver. /// @@ -121,5 +126,17 @@ namespace Tensorflow return names_to_saveables; } + + public static IDictionary saveable_objects_from_trackable(Trackable obj) + { + // TODO: complete the implementation. 
+ return obj.gather_saveables_for_checkpoint(); + } + + public static bool trackable_has_serialize_to_tensor(Trackable obj) + { + // TODO: implement it. + return false; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 79d6dca9..dce0be2a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -14,14 +14,38 @@ limitations under the License. ******************************************************************************/ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.ModelSaving; using static Tensorflow.Binding; namespace Tensorflow.Train { public abstract class Trackable { + /// + /// Corresponding to tensorflow/python/trackable/constants.py + /// + public static class Constants + { + public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; + public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; + public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; + } protected int _self_update_uid; + protected IDictionary _unconditional_dependency_names = + new Dictionary(); + + protected IList _unconditional_checkpoint_dependencies = new List(); + protected IDictionary _self_saveable_object_factories = + new Dictionary(); + public virtual string ObjectIdentifier + { + get => "_generic_user_object"; + } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -73,10 +97,63 @@ namespace Tensorflow.Train /// /// Initialize dependency management. /// - protected void _maybe_initialize_trackable() + public void _maybe_initialize_trackable() { // _self_unconditional_checkpoint_dependencies = [] _self_update_uid = -1; } + + // TODO: cache + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary? 
cache = null)
+        {
+            _maybe_initialize_trackable();
+            return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
+        }
+        /// <summary>Returns <paramref name="obj"/> itself when it is a Trackable; converting any other object is not implemented.</summary>
+        public static Trackable convert_to_trackable(object obj, object? parent = null)
+        {
+            if (obj is Trackable trackable)
+            {
+                return trackable;
+            }
+            else
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        public virtual IDictionary deserialization_dependencies(IDictionary children)
+        {
+            return new Dictionary();
+        }
+
+        public virtual (IDictionary, IDictionary) map_resources(
+            SaveOptions? save_options)
+        {
+            return (new Dictionary(), new Dictionary());
+        }
+        /// <summary>Copies the entries produced by map_resources into the supplied maps and returns the keys of the tensor entries.</summary>
+        public virtual List export_to_saved_model_graph(IDictionary? object_map = null,
+            IDictionary? tensor_map = null, SaveOptions? options = null)
+        {
+            var (self_object_map, self_tensor_map) = map_resources(options);
+            foreach (var pair in self_object_map)
+            {
+                object_map?.Add(pair); // the parameter defaults to null; an omitted map is skipped instead of dereferenced
+            }
+            foreach (var pair in self_tensor_map)
+            {
+                tensor_map?.Add(pair); // same null tolerance as object_map; Add still throws on a duplicate key
+            }
+
+            return self_tensor_map.Keys.ToList();
+        }
+
+        public virtual IDictionary gather_saveables_for_checkpoint()
+        {
+            return _self_saveable_object_factories;
+        }
     }
+
+    public record class TrackableReference(string Name, Trackable Refer);
 }
diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs
new file mode 100644
index 00000000..99020702
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs
@@ -0,0 +1,148 @@
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Exceptions;
+using Tensorflow.Train;
+
+namespace Tensorflow.Training;
+
+public static class TrackableUtils
+{
+    public class CyclicDependencyError: System.Exception
+    {
+        public IDictionary> LeftOverDependencyMap { get; }
+        public CyclicDependencyError(IDictionary> leftover_dependency_map): base()
+        {
+            LeftOverDependencyMap = leftover_dependency_map;
+        }
+        public CyclicDependencyError(IDictionary> leftover_dependency_map): base()
+        {
+            LeftOverDependencyMap = 
leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); + } + } + private static string _ESCAPE_CHAR = "."; + private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + public static string object_path_to_string(IEnumerable node_path_arr) + { + return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); + } + + public static string escape_local_name(string name) + { + return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); + } + + public static string checkpoint_key(string object_path, string local_name) + { + var key_suffix = escape_local_name(local_name); + if (local_name == SERIALIZE_TO_TENSORS_NAME) + { + key_suffix = ""; + } + + return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; + } + + /// + /// Topologically sorts the keys of a map so that dependencies appear first. + /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + /// + /// + /// + public static List order_by_dependency(IDictionary> dependency_map) + { + Dictionary> reverse_dependency_map = new(); + foreach (var pair in dependency_map) + { + foreach (var dep in pair.Value) + { + if (reverse_dependency_map.ContainsKey(dep)) + { + reverse_dependency_map[dep].Add(pair.Key); + } + else + { + reverse_dependency_map[dep] = new HashSet(); + reverse_dependency_map[dep].Add(pair.Key); + } + } + } + + // Validate that all values in the dependency map are also keys. 
+ var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); + if (unknown_keys.Count() > 0) + { + throw new ValueError( + $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); + } + + // Generate the list sorted by objects without dependencies -> dependencies. + // The returned list will reverse this. + List reversed_dependency_arr = new(); + + Queue to_visit = new(); + foreach (var x in dependency_map.Keys) + { + if (!reverse_dependency_map.ContainsKey(x)) + { + to_visit.Enqueue(x); + } + } + + while (to_visit.Count > 0) + { + var x = to_visit.Dequeue(); + reversed_dependency_arr.Add(x); + foreach (var dep in dependency_map[x].Distinct()) + { + var edges = reverse_dependency_map[dep]; + edges.Remove(x); + if (edges.Count == 0) + { + to_visit.Enqueue(dep); + if (!reverse_dependency_map.Remove(dep)) + { + throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); + } + } + } + } + + if (reverse_dependency_map.Count > 0) + { + Dictionary> leftover_dependency_map = new(); + foreach (var pair in reverse_dependency_map) + { + foreach (var x in pair.Value) + { + if (leftover_dependency_map.ContainsKey(x)) + { + leftover_dependency_map[x].Add(pair.Key); + } + else + { + leftover_dependency_map[x] = new List() { pair.Key }; + } + } + } + + throw new CyclicDependencyError(leftover_dependency_map); + } + + reversed_dependency_arr.Reverse(); + return reversed_dependency_arr; + } + + public static string pretty_print_node_path(IEnumerable paths) + { + if (paths.Count() == 0) + { + return "root object"; + } + else + { + return $"root.{string.Join(".", paths.Select(x => x.Name))}"; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index b270ec57..0a050d0f 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ 
b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -2,6 +2,7 @@ using System; using Tensorflow.Eager; using Tensorflow.Variables; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 95e8db57..bf5ae7be 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -566,5 +566,23 @@ namespace Tensorflow else throw new NotImplementedException(""); } + + public static bool inside_function() + { + return get_default_graph().building_function; + } + + public static void dismantle_graph(Graph graph) + { + + } + + public class NullContextManager: IDisposable + { + public void Dispose() + { + + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs new file mode 100644 index 00000000..1675fba1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -0,0 +1,31 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Engine; + +public abstract partial class Layer +{ + public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + + public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + + public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + { + IDictionary children; + if (save_type == SaveType.SAVEDMODEL) + { + // TODO: deal with cache. 
+ children = TrackableSavedModelSaver.trackable_children(cache); + } + else + { + children = new Dictionary(); + } + + return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index ba40b1a2..e95e55d6 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine public bool Built => built; public bool Trainable => args.Trainable; public TF_DataType DType => args.DType; + public bool AutoCast => args.Autocast; + public IRegularizer ActivityRegularizer => args.ActivityRegularizer; /// /// A stateful layer is a layer whose updates are run during inference too, @@ -162,7 +164,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - /// + /// /// protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index c287309d..59f74cd2 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,5 +1,7 @@ using System.Collections.Generic; +using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; namespace Tensorflow.Keras.Engine @@ -18,9 +20,18 @@ namespace Tensorflow.Keras.Engine bool overwrite = true, bool include_optimizer = true, string save_format = "tf", - SaveOptions options = null) + SaveOptions? options = null, + IDictionary? 
signatures = null, + bool save_traces = true) { - saver.save(this, filepath); + if (save_format != "pb") + { + saver.save(this, filepath); + } + else + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 162d06c5..835f6041 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -35,6 +35,12 @@ namespace Tensorflow.Keras.Engine bool _base_model_initialized; bool stop_training; DataHandler data_handler; + + public OptimizerV2 Optimizer + { + get => optimizer; + set => optimizer = value; + } public Model(ModelArgs args) : base(args) diff --git a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs index 61cec646..f29f2dec 100644 --- a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs +++ b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs @@ -194,6 +194,18 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { OnConstruction(); } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(int nodeId, string nodePath, + global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) + { + OnConstruction(); + nodeId_ = nodeId; + nodePath_ = nodePath; + identifier_ = identifier; + metadata_ = metadata; + version_ = version; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Keras/Protobuf/Versions.cs b/src/TensorFlowNET.Keras/Protobuf/Versions.cs index 40405a5a..ff9a23c6 100644 --- a/src/TensorFlowNET.Keras/Protobuf/Versions.cs +++ b/src/TensorFlowNET.Keras/Protobuf/Versions.cs @@ -74,6 +74,13 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { public VersionDef() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public 
VersionDef(int producer, int minConsumer) { + OnConstruction(); + producer_ = producer; + minConsumer_ = minConsumer; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs new file mode 100644 index 00000000..ea6853fd --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public static class Constants +{ + /// + /// Namespace used to store all attributes added during serialization. + /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an + /// object loaded from `tf.saved_model.load()`. + /// + public static readonly string KERAS_ATTR = "keras_api"; + /// + /// Keys for the serialization cache. + /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} + /// + public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; + /// + /// Name of Keras metadata file stored in the SavedModel. 
+ /// + public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; + + public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; + public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; + public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; + public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; + public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; + public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; + public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; + + public static readonly IList KERAS_OBJECT_IDENTIFIERS = new List() + { + INPUT_LAYER_IDENTIFIER, + LAYER_IDENTIFIER, + METRIC_IDENTIFIER, + MODEL_IDENTIFIER, + NETWORK_IDENTIFIER, + RNN_LAYER_IDENTIFIER, + SEQUENTIAL_IDENTIFIER + }; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs new file mode 100644 index 00000000..a5f315bb --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Saving.SavedModel; + +public class KerasObjectWrapper +{ + +} + +public class KerasObjectWrapper +{ + public T Item { get; set; } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs new file mode 100644 index 00000000..76453ca0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Google.Protobuf; +using ICSharpCode.SharpZipLib.Zip; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using 
Tensorflow.IO;
+using Tensorflow.Keras.Optimizers;
+using ThirdParty.Tensorflow.Python.Keras.Protobuf;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    /// <summary>Saves <paramref name="model"/> through SavedModelUtils.save_and_return_nodes and then writes the Keras metadata file under <paramref name="filepath"/>.</summary>
+    public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures, SaveOptions? options, bool save_traces = true)
+    {
+        if (!overwrite && File.Exists(filepath))
+        {
+            throw new Exception("The file already exists but is not allowed to overwrite it.");
+        }
+
+        if (save_traces)
+        {
+            if(should_skip_serialization(model))
+            {
+                throw new NotImplementedException();
+            }
+        }
+        // When the optimizer is excluded it is detached for the duration of the save and put back at the end.
+        OptimizerV2? orig_optimizer = null;
+        if (!include_optimizer)
+        {
+            orig_optimizer = model.Optimizer;
+            model.Optimizer = null;
+            model._delete_tracking("optimizer");
+        }
+
+        IList saved_nodes;
+        IDictionary> node_paths;
+        // skip two scopes of python
+        using (KerasSavedModelUtils.keras_option_scope(save_traces))
+        {
+            (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
+        }
+        // FileMode.Create truncates an existing file; OpenOrCreate would keep stale bytes after a shorter rewrite.
+        var metadata = generate_keras_metadata(saved_nodes, node_paths);
+        using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.Create,
+                   FileAccess.Write))
+        {
+            using var writer = new StreamWriter(f); // disposed (and therefore flushed) before the stream is closed
+            writer.Write(metadata.ToString());
+        }
+
+        if (!include_optimizer)
+        {
+            model.Optimizer = orig_optimizer!;
+        }
+    }
+    /// <summary>Builds the metadata message with one SavedObject per Layer in <paramref name="saved_nodes"/>; nodes that are not layers are skipped.</summary>
+    public static SavedMetadata generate_keras_metadata(IList saved_nodes,
+        IDictionary> node_paths)
+    {
+        var metadata = new SavedMetadata();
+        for (int i = 0; i < saved_nodes.Count; i++)
+        {
+            var node = saved_nodes[i];
+            if (node is not Layer)
+            {
+                continue;
+            }
+
+            Layer layer = (Layer)node;
+
+            var path = node_paths[node];
+            string node_path;
+            if (path is null)
+            {
+                node_path = "root";
+            }
+            else
+            {
+                node_path = $"root.{string.Join(".", path.Select(x => x.Name))}";
+            }
+
+            ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
+            {
+                NodeId = i,
+                NodePath = node_path,
+                Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
+                {
+                    Producer = 2,
+                    MinConsumer = 1,
+                    BadConsumers = { }
+                },
+                Identifier = layer.ObjectIdentifier,
+                Metadata = layer.TrackingMetadata
+            };
+            metadata.Nodes.Add(saved_object); // without this the entry is dropped and the returned metadata stays empty
+        }
+
+        return metadata;
+    }
+
+
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
new file mode 100644
index 00000000..ba0bcc66
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
@@ -0,0 +1,19 @@
+using System.Collections.Generic;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    public static bool should_skip_serialization(object layer)
+    {
+        return false;
+    }
+
+    public static IDictionary wrap_layer_objects(Layer layer, object serialization_cache)
+    {
+        // TODO: process the loss
+
+        return null;
+    }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
new file mode 100644
index 00000000..36111a18
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
@@ -0,0 +1,40 @@
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Keras.Engine;
+using Newtonsoft.Json;
+using Tensorflow.Train;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public abstract class SavedModelSaver
+{
+    private Trackable _obj;
+    public SavedModelSaver(Trackable obj)
+    {
+        _obj = obj;
+    }
+
+    public abstract string ObjectIdentifier { get; }
+    public abstract string TrackingMetadata { get; }
+
+    public abstract IDictionary objects_to_serialize(
+        IDictionary serialization_cache);
+
+    public abstract IDictionary functions_to_serialize(
+        IDictionary serialization_cache);
+
+    public 
IDictionary trackable_children(IDictionary? serialization_cache) + { + if (!KerasSavedModelUtils.ShouldHaveTraces) + { + return new Dictionary(); + } + + var children = objects_to_serialize(serialization_cache); + + return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) + .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + .ToDictionary(x => x.Key, x => x.Value); + } + +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs new file mode 100644 index 00000000..ade8ae73 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class LayerSavedModelSaver: SavedModelSaver +{ + private Layer _obj; + public LayerSavedModelSaver(Layer obj): base(obj) + { + _obj = obj; + } + public override string ObjectIdentifier + { + get => Constants.LAYER_IDENTIFIER; + } + + public override IDictionary objects_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override IDictionary functions_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override string TrackingMetadata + { + get + { + JObject metadata = new JObject(); + metadata["name"] = _obj.Name; + metadata["trainable"] = _obj.Trainable; + // metadata["expects_training_arg"] = _obj._expects_training_arg; + // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + // metadata["stateful"] = _obj.stateful; + // metadata["must_restore_from_config"] = _obj.must_restore_from_config; + 
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; + metadata["autocast"] = _obj.AutoCast; + + metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + { + // Handle conflicts by using values from obj2 + MergeArrayHandling = MergeArrayHandling.Merge + }); + // skip the check of `input_spec` and `build_input_shape` for the lack of members. + // skip the check of `activity_regularizer` for the type problem. + return metadata.ToString(); + } + } + + public static LayerConfig get_serialized(Layer obj) + { + return generic_utils.serialize_keras_object(obj); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs new file mode 100644 index 00000000..30e89582 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -0,0 +1,33 @@ +using System; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public partial class KerasSavedModelUtils +{ + public static bool ShouldHaveTraces { get; internal set; } + + public static SaveOptionsContext keras_option_scope(bool save_traces) + { + var res = new SaveOptionsContext(ShouldHaveTraces); + ShouldHaveTraces = save_traces; + return res; + } +} + +/// +/// Implementation of this class is different with that of python. +/// But it could be used with `using` the same as `with` of python. 
+///
+public class SaveOptionsContext: IDisposable
+{
+    public bool _old_value; // value that Dispose writes back to KerasSavedModelUtils.ShouldHaveTraces
+    public SaveOptionsContext(bool old_value)
+    {
+        _old_value = old_value; // remember the caller's value; a hard-coded `true` made Dispose always enable traces
+    }
+
+    public void Dispose()
+    {
+        KerasSavedModelUtils.ShouldHaveTraces = _old_value;
+    }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs
new file mode 100644
index 00000000..9d1b3088
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs
@@ -0,0 +1,60 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.Keras;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
+using Tensorflow.Keras.Optimizers;
+
+namespace TensorFlowNET.Keras.UnitTest;
+
+// class MNISTLoader
+// {
+//     public MNISTLoader()
+//     {
+//         var mnist = new MnistModelLoader()
+//
+//     }
+// }
+
+[TestClass]
+public class SaveTest
+{
+    [TestMethod]
+    public void Test()
+    {
+        var inputs = new KerasInterface().Input((28, 28, 1));
+        var x = new Flatten(new FlattenArgs()).Apply(inputs);
+        x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
+        x = new LayersApi().Dense(units: 10).Apply(x);
+        var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
+        var model = new KerasInterface().Model(inputs, outputs);
+
+        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"});
+
+        var data_loader = new MnistModelLoader();
+        var num_epochs = 1;
+        var batch_size = 50;
+
+        var dataset = data_loader.LoadAsync(new ModelLoadSetting
+        {
+            TrainDir = "mnist",
+            OneHot = false,
+            ValidationSize = 50000,
+        }).Result;
+
+        
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("", save_format:"pb"); + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 36ff4a3d..56c212d0 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -47,7 +47,7 @@ - + From c4114d5f1815a5281ff2a607dd51e43f17e8a23b Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Mon, 23 Jan 2023 16:25:38 +0800 Subject: [PATCH 03/10] Add more facilities to the saved model framework. --- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 253 ++++++++++++++++++ .../Checkpoint/SaveUtilV1.cs | 18 +- .../Checkpoint/SaveableCompat.cs | 16 ++ .../Checkpoint/TrackableSaver.cs | 109 -------- .../Checkpoint/TrackableView.cs | 8 +- .../Checkpoint/checkpoint.cs | 191 +++++++++++++ .../Checkpoint/functional_saver.cs | 36 +++ .../Protobuf/TrackableObjectGraph.cs | 10 + .../Training/AutoTrackable.cs | 51 +++- .../Training/Saving/SaveSpec.cs | 2 +- .../Training/Saving/SavedModel/save.cs | 52 ++-- .../Saving/saveable_object_util.py.cs | 50 ++++ src/TensorFlowNET.Core/Training/Trackable.cs | 42 ++- 13 files changed, 685 insertions(+), 153 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs delete mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/checkpoint.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/functional_saver.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs new file mode 100644 index 00000000..dc2a92fb --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -0,0 +1,253 @@ +using System; +using System.Collections.Generic; +using 
System.Diagnostics;
+using System.Linq;
+using System.Text;
+using Tensorflow.Train;
+using Tensorflow.Training;
+using pbc = global::Google.Protobuf.Collections;
+
+namespace Tensorflow.Checkpoint
+{
+    internal record class TrackableData(
+        // A trackable in the root Trackable object graph.
+        Trackable trackable,
+        // The index at which the Trackable appears in TrackableObjectGraph.nodes.
+        int node_id,
+        // The BFS-generated path from the root object / used to generate readable checkpoint keys.
+        string object_name,
+        // A list of ObjectReference for each child connected to this Trackable.
+        pbc::RepeatedField children_proto,
+        // A list of SlotVariableReference to save to the object (only valid for Optimizer objects).
+        pbc::RepeatedField slot_variable_proto,
+        // The object to save to checkpoint. Usually this is the same as `trackable`,
+        // but can differ when the caller wants to specify a different object to
+        // save. For example, when saving checkpoints asynchronously, variables are
+        // copied to the CPU. `object_to_save` is set as the copied variable.
+        Trackable object_to_save
+    );
+    public static class SaveUtil
+    {
+        /// <summary>Gathers the trackables of <paramref name="graph_view"/>, fills the object-graph proto and collects the tensors and registered savers to write. Only a null <paramref name="cache"/> is supported.</summary>
+        public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
+        {
+            var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
+            var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);
+
+            var object_graph_proto = fill_object_graph_proto(trackable_data);
+
+            var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
+            var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);
+
+            Dictionary feed_additions;
+            if(cache is null)
+            {
+                feed_additions = null;
+                serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
+                    cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
+            }
+            else
+            {
+                feed_additions = null;
+                // TODO: deal with cache.
+                throw new NotImplementedException(); // the cache path is unwritten; NotFiniteNumberException is reserved for floating-point errors
+            }
+
+            CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
+
+            return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
+        }
+
+        private static (List, Dictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? 
object_map) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach(var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + Dictionary node_ids = new(); + for(int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + List trackable_data = new(); + foreach(var trackable in trackable_objects) + { + pbc::RepeatedField children_proto = new(); + foreach(var child in graph_view.list_children(trackable)) + { + children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { + NodeId = node_ids[child.Refer], + LocalName = child.Name + }); + } + slot_variables.TryGetValue(trackable, out var slot_variable); + trackable_data.Add(new TrackableData( + trackable: trackable, + node_id: node_ids[trackable], + object_name: object_names[trackable], + children_proto: children_proto, + slot_variable_proto: slot_variable??new pbc.RepeatedField(), + object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map) + )); + } + return (trackable_data, node_ids); + } + + private static TrackableObjectGraph fill_object_graph_proto(IList trackable_data) + { + TrackableObjectGraph object_graph_proto = new(); + for(int i = 0; i < trackable_data.Count; i++) + { + var td = trackable_data[i]; + Debug.Assert(td.node_id == i); + object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); + } + return object_graph_proto; + } + + /// + /// Creates dictionary of tensors to checkpoint, and updates the proto. + /// + /// + /// + /// + /// + /// + private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + bool call_with_mapped_captures, object? 
cache, TrackableObjectGraph object_graph_proto) + { + Dictionary> serialized_tensors = new(); + foreach(var td in tensor_trackables) + { + // TODO: deal with cache. + var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; + var trackable = td.object_to_save; + IDictionary tensor_dict; + if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) + { + (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); + } + else + { + tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + } + if(trackable is not null) + { + serialized_tensors[trackable] = tensor_dict; + } + else + { + serialized_tensors[Trackable.None] = tensor_dict; + } + } + return serialized_tensors; + } + + private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + var trackable = trackable_data.object_to_save; + + // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. 
+ IDictionary ret_tensor_dict; + if (call_with_mapped_captures) + { + throw new NotImplementedException(); + } + else + { + ret_tensor_dict = trackable.serialize_to_tensors(); + } + + // TODO: revise the types and complete it + Dictionary tensor_dict = new(); + foreach(var pair in ret_tensor_dict) + { + var local_name = TrackableUtils.escape_local_name(pair.Key); + var maybe_tensor = pair.Value; + var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name); + + tensor_dict[checkpoint_key] = maybe_tensor; + + if(maybe_tensor is SaveSpec) + { + ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + } + + if(object_graph_proto is not null) + { + object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { + Name = local_name, + CheckpointKey = checkpoint_key, + FullName = CheckPointUtils.get_full_name(trackable) + }); + } + } + return tensor_dict; + } + + /// + /// Gets tensors to serialize from a Trackable with legacy SaveableObjects. 
+ /// + /// + /// + /// + /// + /// + private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + Dictionary object_names = new(); + object_names[trackable_data.trackable] = trackable_data.object_name; + Dictionary object_map = new(); + object_map[trackable_data.trackable] = trackable_data.object_to_save; + + var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map); + var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, + call_with_mapped_captures, saveables_cache: null); + var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects); + return (trackable, trackable.serialize_to_tensors()); + } + + private static IDictionary> get_and_write_registered_savers(IDictionary> registered_trackables, TrackableObjectGraph object_graph_proto) + { + Dictionary> registered_savers = new(); + foreach(var pair in registered_trackables) + { + foreach(var td in pair.Value) + { + if (registered_savers.ContainsKey(pair.Key)) + { + registered_savers[pair.Key] = new Dictionary(); + } + else + { + registered_savers[pair.Key][td.object_name] = td.object_to_save; + } + + var object_proto = object_graph_proto.Nodes[td.node_id]; + // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`. + } + } + return registered_savers; + } + + private static (IList, IList, IDictionary>) split_trackables(IEnumerable trackable_data) + { + List tensor_trackables = new(); + List py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a pleceholder. + Dictionary> registered_trackables = new(); + + foreach(var td in trackable_data) + { + // TODO: deal with registration. 
+ tensor_trackables.Add(td); + } + return (tensor_trackables, py_state_trackables, registered_trackables); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 7724c6b7..44fa5c5d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -7,6 +7,7 @@ using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Google.Protobuf; namespace Tensorflow.Checkpoint; @@ -47,19 +48,16 @@ public static class SaveUtilV1 IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { - - Graph target_context; if (to_graph is not null) { - using (to_graph.as_default()) - { - var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + to_graph.as_default(); + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); - named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - return (named_saveable_objects, registered_savers); - } + // tensorflow python: `with ops.device("/cpu:0")` + var serialized = graph_proto.ToByteString().ToString(); + var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); } else { diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs new file mode 100644 index 00000000..fa441d79 --- /dev/null +++ 
b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint +{ + internal static class SaveableCompat + { + public static string? get_saveable_name(Trackable cls_or_obj) + { + // TODO: implement it with Attribute. + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs deleted file mode 100644 index 7d101d5e..00000000 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs +++ /dev/null @@ -1,109 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Contexts; -using Tensorflow.Eager; - -namespace Tensorflow.Checkpoint; - -public class TrackableSaver -{ - private ObjectGraphView _graph_view; - private EagerTensor _cached_save_operation; - private TrackableObjectGraph _last_save_object_graph; - private Tensor? _object_graph_feed_tensor = null; - private Tensor? _file_prefix_feed_tensor = null; - public TrackableSaver(ObjectGraphView graph_view) - { - _graph_view = graph_view; - - // TODO: cache when not executing eagerly. - // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, - // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` - - } - - private void gather_serialized_tensors(Tensor? object_graph_tensor = null) - { - throw new NotImplementedException(); - } - - private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) - { - throw new NotImplementedException(); - } - - // TODO: parameter write_done_callback - public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, - CheckpointOptions? 
options = null) - { - if (options is null) - { - options = new CheckpointOptions(); - } - - Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); - if (checkpoint_number is not null) - { - file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; - } - - Tensor file_prefix_tensor; - Tensor object_graph_tensor; - if (use_session) - { - if (_object_graph_feed_tensor is null) - { - // In python there is `with ops.device("/cpu:0")`. - _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); - _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); - } - - object_graph_tensor = _object_graph_feed_tensor; - file_prefix_tensor = _file_prefix_feed_tensor; - feed_dict[file_prefix_tensor] = file_prefix; - } - else - { - // In python there is `with ops.device("/cpu:0")`. - file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); - object_graph_tensor = null; - } - - var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); - - if (new_feed_additions is not null) - { - foreach (var pair in new_feed_additions) - { - feed_dict.Add(pair.Key, pair.Value); - } - } - if(!use_session) - { - session = null; - } - else if (session is null) - { - session = new Session(); // In python it uses `get_session`. - } - - if (session is not null) - { - var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); - return session.run((Tensor)save_path, s); - } - else if (use_session) - { - throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + - "in graph mode without a default session. 
Please use " + - "`with tf.Session():` to create a session."); - } - else - { - return save_path; - } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index ed1f3ec4..6d81d2c9 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -21,9 +21,14 @@ public class TrackableView public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) { obj._maybe_initialize_trackable(); + Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - return obj._trackable_children(save_type); + foreach(var pair in obj._trackable_children(save_type)) + { + children[pair.Key] = pair.Value; + } + return children; } public Trackable Root @@ -50,6 +55,7 @@ public class TrackableView { List bfs_sorted = new(); Queue to_visit = new(); + to_visit.Enqueue(Root); Dictionary> node_paths = new(); node_paths[this.Root] = new List(); while (!to_visit.empty()) diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs new file mode 100644 index 00000000..79109489 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -0,0 +1,191 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Train; +using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +/// +/// Saves and restores a `Trackable` object and its dependencies. 
+/// +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private Tensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? _file_prefix_feed_tensor = null; + private Dictionary? _object_map = null; + private object? _cache = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); + + // TODO: cache. + + if(object_graph_tensor is null) + { + // tensorflow python: `with ops.device("/cpu:0"):` + object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + } + else + { + feed_additions[object_graph_tensor] = graph_proto.ToString(); + } + Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + if (serialized_tensors.ContainsKey(Trackable.None)) + { + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + } + return (serialized_tensors, feed_additions, registered_savers, graph_proto); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if 
(_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] { save_op })) + { + _cached_save_operation = array_ops.identity(file_prefix); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] {save_op} )) + { + _cached_save_operation = array_ops.identity(tf.constant(file_prefix)); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? 
options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. + file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. 
Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs new file mode 100644 index 00000000..759cbd66 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -0,0 +1,36 @@ +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.OptimizerOptions.Types; + +namespace Tensorflow.Checkpoint +{ + /// + /// Saves checkpoints directly from multiple devices. + /// Note that this is a low-level utility which stores Tensors in the keys + /// specified by `SaveableObject`s.Higher-level utilities for object-based + /// checkpointing are built on top of it. + /// + public class MultiDeviceSaver + { + public MultiDeviceSaver(IDictionary> serialized_tensors, + IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) + { + + } + + public Operation? save(string file_prefix, CheckpointOptions? options= null) + { + throw new NotImplementedException(); + } + + public Operation? save(Tensor file_prefix, CheckpointOptions? 
options = null) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 93413667..fb197eca 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -205,6 +205,16 @@ namespace Tensorflow { slotVariables_ = slot; } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot, + pbc::RepeatedField children + ) + { + OnConstruction(); + slotVariables_ = slot; + children_ = children; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d8f6314b..6f10fd2e 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,4 +1,10 @@ -namespace Tensorflow.Train +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Train { public abstract class AutoTrackable : Trackable { @@ -17,5 +23,48 @@ } } } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if(save_type != SaveType.SAVEDMODEL) + { + return base._trackable_children(save_type, cache); + } + + Dictionary functions = new(); + // TODO: process of logs. + var properties = this.GetType().GetProperties(); + foreach ( var property in properties ) + { + string name = property.Name; + object value = property.GetValue(this, null); + if(value is Function || value is ConcreteFunction) + { + functions[name] = (Trackable)value; + } + } + + // TODO: process the type `core_types.GenericFunction`. 
+ + Dictionary children = new(); + foreach(var pair in CheckpointDependencies) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) // or Generic function + { + continue; + } + if(functions.ContainsKey(name) && functions[name] != child) + { + throw new ValueError($"Can't save object because it has multiple children with the same " + + $"name. Object: {this}, attribute name: {name}, child 1: " + + $"{child}, child 2: {functions[name]}"); + } + children[name] = child; + } + + return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs index 1ae912ce..393a6a98 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -28,7 +28,7 @@ namespace Tensorflow public string slice_spec => _slice_spec; private string _name; - public string name => _name; + public string name { get => _name; set => _name = value; } private TF_DataType _dtype; public TF_DataType dtype => _dtype; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index 69235605..cc839952 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -134,35 +134,33 @@ public static partial class SavedModelUtils Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - using (var g = exported_graph.as_default()) + exported_graph.as_default(); + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. + if (save_custom_gradients) { - (object_map, tensor_map, asset_info) = saveable_view.map_resources(); - // TODO: deal with signatures. - if (save_custom_gradients) - { - // TODO: trace gradient functions. - } + // TODO: trace gradient functions. 
+ } - foreach (var resource_initializer_function in resource_initializers) - { - // List asset_dependencies = new(); - // TODO: deal with initializers - } - - // using(ops.control_dependencies(...)) - var init_op = control_flow_ops.no_op(); - if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); - } - else - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); - } - // Lack `CopyFrom` API - // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers } - + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); @@ -180,11 +178,13 @@ public static partial class SavedModelUtils verify_ops(graph_def, namespace_whitelist); meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef = new(); meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; // TODO: add git version. 
meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList = new(); meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 98cdb274..622eed3a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -138,5 +138,55 @@ namespace Tensorflow // TODO: implement it. return false; } + + internal static string convert_to_string(string x) + { + return tf.compat.as_str(x); + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private Trackable _obj; + private IList _saveables; + public SaveableCompatibilityConverter(Trackable obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public Trackable Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary serialize_to_tensors() + { + return saveable_objects_to_tensor_dict(_saveables); + } + + /// + /// Converts a list of SaveableObjects to a tensor dictionary. 
+ /// + /// + public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + { + Dictionary tensor_dict = new(); + foreach (var saveable in saveables) + { + foreach(var spec in saveable.specs) + { + var name = saveable_object_util.convert_to_string(spec.name); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + if (!string.IsNullOrEmpty(slice_spec)) + { + throw new NotImplementedException(); + } + else + { + tensor_dict[name] = spec.tensor; + } + } + } + return tensor_dict; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index dce0be2a..b98075d3 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -34,18 +34,35 @@ namespace Tensorflow.Train public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; } protected int _self_update_uid; - protected IDictionary _unconditional_dependency_names = - new Dictionary(); + protected IDictionary _unconditional_dependency_names; - protected IList _unconditional_checkpoint_dependencies = new List(); + protected IList _unconditional_checkpoint_dependencies; protected IDictionary _self_saveable_object_factories = new Dictionary(); + + private static Trackable _none = new Function(); + /// + /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. + /// The `None` can be any object that inherits `Trackable`. + /// This Property is supposed to be used only internal. 
+ /// + public static Trackable None + { + get + { + return _none; + } + } public virtual string ObjectIdentifier { get => "_generic_user_object"; } - + public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } + public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } + public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } + public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -99,8 +116,9 @@ namespace Tensorflow.Train /// public void _maybe_initialize_trackable() { - // _self_unconditional_checkpoint_dependencies = [] _self_update_uid = -1; + _unconditional_checkpoint_dependencies = new List(); + _unconditional_dependency_names = new Dictionary(); } // TODO: cache @@ -153,6 +171,20 @@ namespace Tensorflow.Train { return _self_saveable_object_factories; } + + /// + /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors` + /// if you are defining a custom resource or variable with custom ops. + /// Otherwise, please store the state of your trackable in `tf.Variable` objects + /// and add them to Trackable object hierarchy using `setattr` (for subclasses + /// of `AutoTrackable`) or overriding the `_trackable_children` method. + /// + /// + /// + public virtual IDictionary serialize_to_tensors() + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer); From ddd06ab9b6d1bc18229630d98d5f062658b768a9 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 24 Jan 2023 18:01:22 +0800 Subject: [PATCH 04/10] Add ListWrapper and ITrackable, and revise implmentations. 
--- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 3 +- .../Operations/NnOps/RNNCell.cs | 3 + src/TensorFlowNET.Core/Training/ITrackable.cs | 12 + src/TensorFlowNET.Core/Training/LayerUtils.cs | 9 + src/TensorFlowNET.Core/Training/Trackable.cs | 50 ++- .../Training/data_structures.cs | 364 ++++++++++++++++++ .../Variables/BaseResourceVariable.cs | 4 +- .../Variables/IVariableV1.cs | 1 + .../Variables/RefVariable.cs | 1 + src/TensorFlowNET.Keras/Engine/Functional.cs | 31 ++ src/TensorFlowNET.Keras/Engine/Model.cs | 11 + .../Saving/SavedModel/layer_serialization.cs | 12 + .../SavedModel/serialized_attributes.cs | 14 + .../Saving/SavedModel/utils.cs | 4 +- 14 files changed, 513 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Core/Training/ITrackable.cs create mode 100644 src/TensorFlowNET.Core/Training/LayerUtils.cs create mode 100644 src/TensorFlowNET.Core/Training/data_structures.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f77b4a86..f1ca5632 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,10 +1,11 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer + public interface ILayer: ITrackable { string Name { get; } bool Trainable { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 04fdc7e5..734f2608 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Operations; +using Tensorflow.Train; using 
Tensorflow.Util; using static Tensorflow.Binding; @@ -147,5 +148,7 @@ namespace Tensorflow { throw new NotImplementedException(); } + + public Trackable GetTrackable() { throw new NotImplementedException(); } } } diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/ITrackable.cs new file mode 100644 index 00000000..e4ef2c8f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/ITrackable.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + public interface ITrackable + { + Trackable GetTrackable(); + } +} diff --git a/src/TensorFlowNET.Core/Training/LayerUtils.cs b/src/TensorFlowNET.Core/Training/LayerUtils.cs new file mode 100644 index 00000000..21141965 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/LayerUtils.cs @@ -0,0 +1,9 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + +} diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index b98075d3..2646fb8d 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -18,11 +18,12 @@ using System; using System.Collections.Generic; using System.Linq; using Tensorflow.ModelSaving; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class Trackable + public abstract class Trackable: ITrackable { /// /// Corresponding to tensorflow/python/trackable/constants.py @@ -40,6 +41,7 @@ namespace Tensorflow.Train protected IDictionary _self_saveable_object_factories = new Dictionary(); + private bool _manual_tracking = true; private static Trackable _none = new Function(); /// @@ -54,6 +56,10 @@ namespace Tensorflow.Train return _none; } } + public Trackable GetTrackable() + { + return this; + } public virtual string ObjectIdentifier { get => 
"_generic_user_object"; @@ -128,6 +134,48 @@ namespace Tensorflow.Train return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); } + public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false) + { + _maybe_initialize_trackable(); + if (!_manual_tracking) return trackable; + var new_reference = new TrackableReference(name, trackable); + var current_object = _lookupup_dependency(name); + + if(current_object is null) + { + _unconditional_checkpoint_dependencies.Add(new_reference); + _handle_deferred_dependencies(name, trackable); + } + _unconditional_dependency_names[name] = trackable; + return trackable; + } + + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for + /// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are + /// considered fulfilled and so are deleted. + /// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations. + /// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves, + /// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context + /// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs). + /// + /// + /// + public virtual void _handle_deferred_dependencies(string name, Trackable trackable) + { + //_maybe_initialize_trackable(); + //trackable._maybe_initialize_trackable(); + + // TODO: complete the implementation. + } + + public virtual Trackable? 
_lookupup_dependency(string name) + { + if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; + else return null; + } + public static Trackable convert_to_trackable(object obj, object? parent = null) { if (obj is Trackable) diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs new file mode 100644 index 00000000..4cb78181 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -0,0 +1,364 @@ +using Google.Protobuf; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO.Compression; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow.Functions; +using Tensorflow.Keras; +using Tensorflow.Operations.Activation; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; + +namespace Tensorflow.Training +{ + public class NoDependency + { + public Trackable Value { get; set; } + public NoDependency(Trackable value) + { + Value = value; + } + } + + public abstract class TrackableDataStructure : Trackable + { + private bool _self_trainable; + private List _self_extra_variables; + + public TrackableDataStructure() + { + _self_trainable = true; + _self_extra_variables = new List(); + } + + public abstract IEnumerable Values { get; } + public bool Trainable { get => _self_trainable; set => _self_trainable = value; } + public IEnumerable Layers + { + get + { + List collected = new(); + foreach(var obj in Values) + { + if(obj is ILayer) + { + collected.Add((ILayer)obj); + } + else if(obj is TrackableDataStructure) + { + collected.AddRange((obj as TrackableDataStructure).Layers); + } + } + return collected; + } + } + public IEnumerable TrainableWeights + { + get + { + if (!_self_trainable) + { + return new List(); + } + List trainable_variables = new(); + foreach (var obj in Values) + { + // skip the process of `module.Module`. 
+ if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + foreach(var v in _self_extra_variables) + { + if (v.Trainable) + { + trainable_variables.Add(v); + } + } + return trainable_variables; + } + } + public IEnumerable NonTrainableWeights + { + get + { + var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList(); + var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList(); + List non_trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables); + } + } + + if (!_self_trainable) + { + // Return order is all trainable vars, then all non-trainable vars. + List trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables); + } + else + { + return non_trainable_variables.concat(non_trainable_extra_variables); + } + } + } + public IEnumerable Weights => TrainableWeights.Concat(NonTrainableWeights); + public IEnumerable TrainableVariables => TrainableWeights; + public IEnumerable NonTrainableVariables => NonTrainableWeights; + public IEnumerable Variables => Weights; + + // TODO: `losses` property. + + /// + /// Add a dependency on `value`. + /// + /// + /// + protected virtual Trackable _track_value(Trackable value, string name) + { + value = sticky_attribute_assignment(this, name, value); + if(value is IVariableV1) + { + _self_extra_variables.Add(value as IVariableV1); + } + // skip the left process (need to be done in the future). 
+ return value; + } + + protected static Trackable wrap_or_unwrap(NoDependency value) + { + return value.Value; + } + + protected static Trackable wrap_or_unwrap(Trackable value) + { + return value; + } + + protected static Trackable wrap_or_unwrap(IList value) + { + return new ListWrapper(value); + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) + { + value = wrap_or_unwrap(value); + trackable._track_trackable(value, name, true); + return value; + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) + { + var wrapped_value = wrap_or_unwrap(value); + trackable._track_trackable(wrapped_value, name, true); + return wrapped_value; + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList value) + { + var wrapped_value = wrap_or_unwrap(value); + trackable._track_trackable(wrapped_value, name, true); + return wrapped_value; + } + } + + public class ListWrapper : TrackableDataStructure, IList, ICloneable + { + private IList _storage; + private bool _non_append_mutation_value; + private bool _external_modification_value; + private IList _last_wrapped_list_snapshot; + /// + /// + /// + /// The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be + /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save. + public ListWrapper(IList wrapped_list) + { + _storage = wrapped_list; + _non_append_mutation_value = _external_modification_value = false; + _last_wrapped_list_snapshot = new List(_storage); + } + + protected bool NonAppendMuation { + get => _non_append_mutation_value; + set + { + // TODO: deal with `attribute_sentinel`. 
+ _non_append_mutation_value = value; + } + } + + protected bool ExternalModification + { + get => _external_modification_value; + set + { + // TODO: deal with `attribute_sentinel`. + _external_modification_value = value; + } + } + + public override IEnumerable Values => this; + public bool IsReadOnly { get => _storage.IsReadOnly; } + + /// + /// Checks for any changes to the wrapped list not through the wrapper. + /// + private void check_external_modification() + { + if (_external_modification_value || _non_append_mutation_value) return; + if (!_storage.SequenceEqual(_last_wrapped_list_snapshot)) + { + _external_modification_value = true; + } + } + + private void update_snapshot() + { + // TODO: deal with `attribute_sentinel`. + if (_external_modification_value || _non_append_mutation_value) return; + _last_wrapped_list_snapshot = new List(_storage); + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + check_external_modification(); + if (_non_append_mutation_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" + + $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." + + $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored."); + } + if (_external_modification_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). 
The wrapped list was modified " + + $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " + + $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored."); + } + var children = base._trackable_children(save_type, cache); + + if(save_type == SaveType.SAVEDMODEL) + { + children = children.Concat(this.Select((x, idx) => new KeyValuePair(idx.ToString(), x)).Where(pair => pair.Value is Function or ConcreteFunction)).GroupBy(pair => pair.Key).ToDictionary(group => group.Key, group => group.Last().Value); + } + + return children; + } + + private bool has_mutation_or_trackable() + { + return _non_append_mutation_value; + } + + /// + /// Allows storage of non-trackable objects. + /// + /// + /// + /// + protected override Trackable _track_value(Trackable value, string name) + { + try + { + value = base._track_value(value, name); + } + catch(ValueError ex) + { + value = sticky_attribute_assignment(this, name, value); + } + return value; + } + + public object Clone() + { + var res = new ListWrapper(_storage.Select(x => x).ToList()); + res.NonAppendMuation= _non_append_mutation_value; + res.ExternalModification = _external_modification_value; + return res; + } + + public Trackable this[int index] { + get => _storage[index]; + set + { + // skip the process of `Slice`, maybe support it in the future. 
+ _non_append_mutation_value = true; + _storage[index] = _track_value(value, _name_element(index)); + + update_snapshot(); + } + } + + public int IndexOf(Trackable item) => _storage.IndexOf(item); + + public void Insert(int index, Trackable item) + { + check_external_modification(); + _non_append_mutation_value = true; + _storage.Insert(index, item); + update_snapshot(); + } + + public void RemoveAt(int index) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + _storage.RemoveAt(index); + update_snapshot(); + } + + public int Count { get => _storage.Count; } + + public void Add(Trackable item) + { + check_external_modification(); + _storage.Add(item); + update_snapshot(); + } + + public void Clear() => _storage.Clear(); + + public bool Contains(Trackable item) => _storage.Contains(item); + + public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex); + + public bool Remove(Trackable item) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + var res = _storage.Remove(item); + update_snapshot(); + return res; + } + + public IEnumerator GetEnumerator() => _storage.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); + + protected string _name_element(int index) => $"{index}"; + } +} diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 0a050d0f..4526730f 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -22,7 +22,7 @@ namespace Tensorflow protected bool _in_graph_mode; protected bool _trainable; - public bool trainable => _trainable; + public bool Trainable => _trainable; protected Tensor _initial_value; @@ -166,7 +166,7 @@ namespace Tensorflow /// void variable_accessed(BaseResourceVariable variable) { - 
if (variable.trainable) + if (variable.Trainable) { foreach (var tape in tf.GetTapeSet()) tape.VariableAccessed(variable as ResourceVariable); diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index f4f716c3..3eb78153 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -46,6 +46,7 @@ namespace Tensorflow Graph Graph { get; } TF_DataType dtype { get; } Shape shape { get; } + bool Trainable { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true); IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 67c12c42..38b5b734 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -56,6 +56,7 @@ namespace Tensorflow public string Name => _variable.name; public Tensor eval() => _variable; + public bool Trainable => _trainable; public RefVariable(object initial_value = null, bool trainable = true, diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 09a31b94..61a8956a 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Utils; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -20,6 +21,30 @@ namespace Tensorflow.Keras.Engine Dictionary tensor_usage_count; + /// + /// Dictionary of layer dependencies to be included in the checkpoint. 
+ /// + public IDictionary LayerCheckpointDependencies + { + get + { + int weight_layer_index = 0; + Dictionary dependencies = new(); + for(int i = 0; i < Layers.Count; i++) + { + var layer = Layers[i]; + var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); + if(weights.Count > 0) + { + dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; + weight_layer_index++; + } + dependencies[$"layer-{i}"] = layer; + } + return dependencies; + } + } + public Functional(Tensors inputs, Tensors outputs, string name = null) : base(new ModelArgs { @@ -325,5 +350,11 @@ namespace Tensorflow.Keras.Engine return output_tensors; } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + { + return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) + .ToDictionary(x => x.Key, x => x.Value); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 835f6041..41f7788e 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -108,5 +109,15 @@ namespace Tensorflow.Keras.Engine return variables; } } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + { + if(save_type == SaveType.SAVEDMODEL) + { + //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. 
+ } + var children = base._trackable_children(save_type, cache); + return children; + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index ade8ae73..f0ad7450 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -29,6 +29,18 @@ public class LayerSavedModelSaver: SavedModelSaver throw new System.NotImplementedException(); } + /// + /// Generates or retrieves serialized attributes from cache. + /// + /// + protected void get_serialized_attributes(IDictionary serialization_cache) + { + // TODO: deal with cache. + Layer a; + + + } + public override string TrackingMetadata { get diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs new file mode 100644 index 00000000..6a163fec --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + /// + /// Class that tracks and validates all serialization attributes. 
+ /// + public class SerializedAttributes + { + + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index 30e89582..a5d84d67 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -4,7 +4,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static bool ShouldHaveTraces { get; internal set; } + public static bool ShouldHaveTraces { get; internal set; } = true; public static SaveOptionsContext keras_option_scope(bool save_traces) { @@ -23,7 +23,7 @@ public class SaveOptionsContext: IDisposable public bool _old_value; public SaveOptionsContext(bool old_value) { - _old_value = true; + _old_value = old_value; } public void Dispose() From bdca3b5e3d92514a0b816f4a8a81b0864428ebf8 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 24 Jan 2023 22:03:53 +0800 Subject: [PATCH 05/10] Add serialized attributes. 
--- .../Training/AutoTrackable.cs | 2 +- .../SavedModel/serialized_attributes.cs | 267 +++++++++++++++++- 2 files changed, 267 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 6f10fd2e..5dd9784f 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class AutoTrackable : Trackable + public class AutoTrackable : Trackable { public void _delete_tracking(string name) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 6a163fec..ff3c7875 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -1,14 +1,279 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Metrics; +using Tensorflow.Train; namespace Tensorflow.Keras.Saving.SavedModel { + // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#, + // Using the name "Attributes" may be quite confusing. /// /// Class that tracks and validates all serialization attributes. 
/// - public class SerializedAttributes + public abstract class SerializedAttributes { + protected IDictionary _object_dict; + protected IDictionary _function_dict; + protected AutoTrackable _keras_trackable; + protected HashSet _all_functions; + protected HashSet _all_checkpointable_objects; + protected SerializedAttributes() + { + _object_dict= new Dictionary(); + _function_dict= new Dictionary(); + _keras_trackable= new AutoTrackable(); + _all_functions= new HashSet(); + _all_checkpointable_objects= new HashSet(); + } + + protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(checkpointable_objects); + _all_functions = new HashSet(functions); + } + + public IDictionary Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + public IDictionary CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + /// + /// Returns functions to attach to the root object during serialization. + /// + public IDictionary FunctionsToSerialize + { + get + { + Dictionary functions = new(); + foreach(var pair in Functions) + { + if (_all_functions.Contains(pair.Key)) + { + // TODO: deal with `LayerCall`. + functions[pair.Key] = pair.Value; + } + } + return functions; + } + } + + /// + /// Returns objects to attach to the root object during serialization. + /// + public IDictionary ObjectsToSerialize + { + get + { + var objects = CheckpointableObjects.Where( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + objects[Constants.KERAS_ATTR] = _keras_trackable; + return objects; + } + } + + /// + /// Saves function dictionary, and validates dictionary values. 
+ /// + /// + public IDictionary set_and_validate_functions(IDictionary function_dict) + { + foreach(var key in _all_functions) + { + if (function_dict.ContainsKey(key)) + { + // TODO: deal with type `LayerCall`. + var fn = function_dict[key]; + if (fn is not null && (fn is not Function)) + { + throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key})."); + } + _function_dict[key] = fn; + + var tf_fn = fn; // TODO: deal with type `LayerCall`. + + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if(property.Name == key) + { + property.SetValue(_keras_trackable, tf_fn); + break; + } + } + } + else + { + throw new ValueError($"Function {key} missing from serialized function dict."); + } + } + return Functions; + } + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + public IDictionary set_and_validate_objects(IDictionary object_dict) + { + foreach(var key in _all_checkpointable_objects) + { + if (object_dict.ContainsKey(key)) + { + _object_dict[key] = object_dict[key]; + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if (property.Name == key) + { + property.SetValue(_keras_trackable, object_dict[key]); + break; + } + } + } + else + { + throw new ValueError($"Object {key} missing from serialized object dict."); + } + } + return CheckpointableObjects; + } + + /// + /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python). 
+ /// + /// + public static SerializedAttributes Create(Trackable obj) + { + if(obj is Model) + { + return new ModelAttributes(); + } + else if(obj is Metric) + { + return new MetricAttributes(); + } + else if(obj is RNN) + { + return new RNNAttributes(); + } + else if(obj is Layer) + { + return new LayerAttributes(); + } + else + { + throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}"); + } + } + + protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + return (checkpointable_objects ?? (new List()), functions ?? (new List())); + } + } + + // Note that the current implementation still has some potential risks. + // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras". + // However, currently it's just a normal class. + public class CommonEndPoints: SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if(checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if(functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // TODO: remove the `__call__`. + functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + ); + } + } + + public class LayerAttributes: CommonEndPoints + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + ); + } + } + + public class ModelAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + } + } + + public class MetricAttributes : SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables" }), + functions + ); + } + } + + public class RNNAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "states" }), + functions + ); + } } } From b92b08d6290477150c403711b98778e8cae55425 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Wed, 25 Jan 2023 10:14:15 +0800 Subject: [PATCH 06/10] Implement layer serializations. --- .../Checkpoint/TrackableView.cs | 2 +- src/TensorFlowNET.Core/DisposableObject.cs | 68 ++++++++ .../Saving/SavedModel/AugmentedGraphView.cs | 4 +- .../Training/data_structures.cs | 11 +- .../Variables/BaseResourceVariable.cs | 2 +- .../Variables/RefVariable.cs | 3 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 2 + .../Saving/SavedModel/SaveImpl.cs | 53 ++++++- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 39 ++++- .../SavedModel/serialized_attributes.cs | 145 +++++++++--------- .../Saving/SavedModel/utils.cs | 14 ++ 12 files changed, 257 insertions(+), 93 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 6d81d2c9..69bf76fd 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -24,7 +24,7 @@ public class TrackableView Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. 
- foreach(var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 3c70739b..7fac3d0f 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -17,6 +17,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Tensorflow.Train; namespace Tensorflow { @@ -90,4 +91,71 @@ namespace Tensorflow Dispose(false); } } + + public abstract class DisposableTrackableObject: Trackable, IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableTrackableObject() + { } + + protected DisposableTrackableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. + // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. 
+ // Therefore, you should call GC.SupressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableTrackableObject() + { + Dispose(false); + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 6723206c..82da2ee9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -23,10 +23,10 @@ public class AugmentedGraphView: ObjectGraphView list_children(Root); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) { Dictionary children = new(); - foreach (var pair in base.list_children(obj, save_type)) + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) { var name = pair.Name; var child = pair.Refer; diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index 4cb78181..d4e9c401 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -142,21 +142,26 @@ namespace Tensorflow.Training return value; } - protected static Trackable wrap_or_unwrap(NoDependency value) + public static Trackable wrap_or_unwrap(NoDependency value) { return value.Value; } - protected static Trackable wrap_or_unwrap(Trackable value) + public static Trackable wrap_or_unwrap(Trackable value) { return value; } - protected static Trackable wrap_or_unwrap(IList value) + public static Trackable wrap_or_unwrap(IList value) { return new ListWrapper(value); } + public static Trackable wrap_or_unwrap(IEnumerable value) + { + return new 
ListWrapper(value.ToList()); + } + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) { value = wrap_or_unwrap(value); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 4526730f..f217a052 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -7,7 +7,7 @@ using static Tensorflow.Binding; namespace Tensorflow { - public class BaseResourceVariable : DisposableObject + public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 38b5b734..7b08f3ea 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -20,11 +20,12 @@ using System; using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; +using Tensorflow.Train; namespace Tensorflow { [Obsolete] - public partial class RefVariable : IVariableV1, IProtoBuf + public partial class RefVariable: Trackable, IVariableV1, IProtoBuf { protected string _name; public string UniqueId => _name; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e95e55d6..b9b01dae 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -288,6 +288,8 @@ namespace Tensorflow.Keras.Engine } } + public List Variables => weights; + public virtual LayerArgs get_config() => args; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index ba0bcc66..7168e25b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -1,5 +1,8 @@ 
using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training; namespace Tensorflow.Keras.Saving.SavedModel; @@ -10,10 +13,54 @@ public partial class KerasSavedModelUtils return false; } - public static IDictionary wrap_layer_objects(Layer layer, object serialization_cache) + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) { - // TODO: process the loss + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. - return null; + // TODO: change the inherits of `Variable` and revise the implmentation. + var variables = layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var trainable_variables = layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var non_trainable_variables = layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + + Dictionary res = new(); + res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); + res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); + res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. 
+ /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. + + return new Dictionary(); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 36111a18..a399eaf1 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -17,10 +17,10 @@ public abstract class SavedModelSaver public abstract string ObjectIdentifier { get; } public abstract string TrackingMetadata { get; } - public abstract IDictionary objects_to_serialize( + public abstract IDictionary objects_to_serialize( IDictionary serialization_cache); - public abstract IDictionary functions_to_serialize( + public abstract IDictionary functions_to_serialize( IDictionary serialization_cache); public IDictionary trackable_children(IDictionary? 
serialization_cache) @@ -32,8 +32,7 @@ public abstract class SavedModelSaver var children = objects_to_serialize(serialization_cache); - return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) - .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index f0ad7450..7a0ddd21 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,26 +19,51 @@ public class LayerSavedModelSaver: SavedModelSaver get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } /// /// Generates or retrieves serialized attributes from cache. /// /// - protected void get_serialized_attributes(IDictionary serialization_cache) + protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) { // TODO: deal with cache. - Layer a; - + var serialized_attr = SerializedAttributes.Create(_obj); + + // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. 
+ if (KerasSavedModelUtils.should_skip_serialization(_obj)) + { + return serialized_attr; + } + + var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); + + serialized_attr.set_and_validate_objects(object_dict); + serialized_attr.set_and_validate_functions(function_dict); + return serialized_attr; + } + + /// + /// Returns dictionary of serialized attributes. + /// + /// + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + { + var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + + functions["_default_save_signature"] = null; + + return (objects, functions); } public override string TrackingMetadata diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index ff3c7875..804ea1a9 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -17,15 +17,15 @@ namespace Tensorflow.Keras.Saving.SavedModel public abstract class SerializedAttributes { protected IDictionary _object_dict; - protected IDictionary _function_dict; + protected IDictionary _function_dict; protected AutoTrackable _keras_trackable; protected HashSet _all_functions; protected HashSet _all_checkpointable_objects; - protected SerializedAttributes() + private SerializedAttributes() { _object_dict= new Dictionary(); - _function_dict= new Dictionary(); + _function_dict= new Dictionary(); _keras_trackable= new AutoTrackable(); _all_functions= new HashSet(); _all_checkpointable_objects= new HashSet(); @@ -34,25 +34,35 @@ namespace Tensorflow.Keras.Saving.SavedModel protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) { _object_dict = new Dictionary(); - _function_dict = new 
Dictionary(); + _function_dict = new Dictionary(); _keras_trackable = new AutoTrackable(); _all_checkpointable_objects = new HashSet(checkpointable_objects); _all_functions = new HashSet(functions); } - public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + protected SerializedAttributes((IEnumerable, IEnumerable) objects_and_functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(objects_and_functions.Item1); + _all_functions = new HashSet(objects_and_functions.Item2); + } + + public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); /// /// Returns functions to attach to the root object during serialization. /// - public IDictionary FunctionsToSerialize + public IDictionary FunctionsToSerialize { get { - Dictionary functions = new(); + Dictionary functions = new(); foreach(var pair in Functions) { if (_all_functions.Contains(pair.Key)) @@ -82,7 +92,7 @@ namespace Tensorflow.Keras.Saving.SavedModel /// Saves function dictionary, and validates dictionary values. /// /// - public IDictionary set_and_validate_functions(IDictionary function_dict) + public IDictionary set_and_validate_functions(IDictionary function_dict) { foreach(var key in _all_functions) { @@ -186,94 +196,87 @@ namespace Tensorflow.Keras.Saving.SavedModel // However, currently it's just a normal class. public class CommonEndPoints: SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) + base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), + functions.Concat(new string[] { })) { - if(checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if(functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), - // TODO: remove the `__call__`. - functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) - ); + + } + + public CommonEndPoints() : + //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, + // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + base(new string[] { "variables", "trainable_variables"}, + new string[] {}) + { + } } public class LayerAttributes: CommonEndPoints { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public LayerAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), + functions.Concat(new string[] { })) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), - functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) - ); + + } + + public LayerAttributes() : + //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, + // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(new string[] { "non_trainable_variables", "layers" }, + new string[] { }) + { + } } public class ModelAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public ModelAttributes(IEnumerable checkpointable_objects, IEnumerable functions): + base(checkpointable_objects, functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + + } + + public ModelAttributes(): base() + { + } } public class MetricAttributes : SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + public MetricAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects.Concat(new string[] { "variables" }), functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables" }), - functions - ); + + } + + public MetricAttributes() : + base(new string[] { "variables" }, new string[] {}) + { + } } public class RNNAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public RNNAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects, functions.Concat(new string[] {"states"})) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "states" }), - functions - ); + + } + + public RNNAttributes() : + base(new string[] { }, new string[] { "states" }) + { + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index a5d84d67..3054271a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving.SavedModel; @@ -12,6 +14,18 @@ public partial class KerasSavedModelUtils ShouldHaveTraces = save_traces; return res; } + + public static IEnumerable list_all_layers(Layer layer) + { + if(layer is Model) + { + return (layer as Model).Layers; + } + else + { + return new List(layer._flatten_layers(false, false)); + } + } } /// From 83906b8f798d7faa99784da7d66489ca51dae4fd Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Mon, 30 Jan 2023 13:42:51 +0800 Subject: [PATCH 07/10] Add lacked implementations (mainly MultiDeviceSaver). 
--- .../Checkpoint/CheckpointOptions.cs | 2 +- .../Checkpoint/ObjectGraphView.cs | 9 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 23 +- .../Checkpoint/SaveUtilV1.cs | 27 +- .../Checkpoint/TrackableView.cs | 5 +- .../Checkpoint/checkpoint.cs | 9 +- .../Checkpoint/functional_saver.cs | 515 +++++++++++++++++- .../SavedModel/ISerializedAttributes.cs | 35 ++ .../Training/AutoTrackable.cs | 3 +- .../Saving/SavedModel/AugmentedGraphView.cs | 109 +++- .../Saving/SavedModel/SaveableView.cs | 6 +- .../Training/Saving/SavedModel/save.cs | 16 +- .../SavedModel/signature_serialization.cs | 99 +++- .../Saving/saveable_object_util.py.cs | 156 +++++- src/TensorFlowNET.Core/Training/Trackable.cs | 48 +- .../Training/TrackableUtils.cs | 28 +- .../Training/data_structures.cs | 3 +- .../Variables/BaseResourceVariable.cs | 3 + .../Variables/ResourceVariable.cs | 9 + src/TensorFlowNET.Keras/Engine/Functional.cs | 3 +- .../Engine/Layer.Serialize.cs | 7 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 24 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.cs | 3 +- .../Saving/SavedModel/Save.cs | 9 +- .../Saving/SavedModel/SaveImpl.cs | 4 +- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 28 +- .../SavedModel/serialized_attributes.cs | 2 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 4 +- 30 files changed, 1037 insertions(+), 161 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs index d8297ea3..f14b5ce7 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -1,5 +1,5 @@ namespace Tensorflow.Checkpoint; public record class CheckpointOptions( - string experimental_io_device = null, + string? 
experimental_io_device = null, bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs index 2ad55448..cb01b539 100644 --- a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Serilog.Debugging; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow.Checkpoint; @@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable return new ObjectGraphView(Root, _attached_dependencies); } - public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - List res = base.children(obj, save_type) + List res = base.children(obj, save_type, serialization_cache) .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); // Check the reference, not value. if (obj == Root && _attached_dependencies is not null) @@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable return res; } - public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); } public IEnumerable? 
AttachedDependencies diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index dc2a92fb..e646f1f0 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint /// /// /// - private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) { - Dictionary> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach(var td in tensor_trackables) { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? 
""; var trackable = td.object_to_save; - IDictionary tensor_dict; + IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); @@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint return serialized_tensors; } - private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { var trackable = trackable_data.object_to_save; // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. - IDictionary ret_tensor_dict; + IDictionary>> ret_tensor_dict; if (call_with_mapped_captures) { throw new NotImplementedException(); @@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint ret_tensor_dict = trackable.serialize_to_tensors(); } - // TODO: revise the types and complete it - Dictionary tensor_dict = new(); + // TODO: deal with the type `SaveSpce` (currently it will never be it). 
+ Dictionary>> tensor_dict = new(); foreach(var pair in ret_tensor_dict) { var local_name = TrackableUtils.escape_local_name(pair.Key); @@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor is SaveSpec) + if(maybe_tensor.GetValueA() is SaveSpec) { - ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + throw new NotImplementedException(); + //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; } if(object_graph_proto is not null) @@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint /// /// /// - private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { Dictionary object_names = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 44fa5c5d..d8e251ec 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -174,25 +174,20 @@ public static class SaveUtilV1 { var name = factory_data.name; var key = factory_data.checkpoint_key; - var saveable_factory = factory_data.factory; - + var maybe_saveable = factory_data.factory; + // TODO: oneflow python has a process with callable `saveable_factory`. 
- var maybe_saveable = saveable_factory; - IEnumerable savesbles; - if (maybe_saveable is MySaveableObject) - { - savesbles = new List() { (MySaveableObject)maybe_saveable }; - } - else if (maybe_saveable is Tensor) + List saveables = new(); + if (maybe_saveable.DataType == typeof(MySaveableObject)) { - savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + saveables.Add(maybe_saveable.GetValueB()); } else { - throw new TypeError("Unexpected type."); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); } - foreach (var saveable in savesbles) + foreach (var saveable in saveables) { if (!saveable.name.Contains(key)) { @@ -204,11 +199,11 @@ public static class SaveUtilV1 // skip the process of PythonState - named_saveable_objects.AddRange(savesbles); + named_saveable_objects.AddRange(saveables); if(!fill_object_proto) continue; - - // skip the process of TrackableSaveable + + // skip the process of `TrackableSaveable` because of lack of APIs. 
object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); @@ -221,7 +216,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - object factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 69bf76fd..f89dc10d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -2,6 +2,7 @@ using Tensorflow.Train; using System.Collections.Generic; using System.IO; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow.Checkpoint; @@ -18,13 +19,13 @@ public class TrackableView _root_ref = obj; } - public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { obj._maybe_initialize_trackable(); Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - foreach (var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type, cache)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 79109489..c9bee0db 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public class TrackableSaver } - private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? 
object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -125,7 +125,7 @@ public class TrackableSaver } Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; @@ -133,6 +133,7 @@ public class TrackableSaver Tensor file_prefix_tensor; Tensor object_graph_tensor; + string file_prefix_to_save; if (use_session) { if (_object_graph_feed_tensor is null) @@ -145,16 +146,18 @@ public class TrackableSaver object_graph_tensor = _object_graph_feed_tensor; file_prefix_tensor = _file_prefix_feed_tensor; feed_dict[file_prefix_tensor] = file_prefix; + file_prefix_to_save = ""; } else { // In python there is `with ops.device("/cpu:0")`. file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); object_graph_tensor = null; + file_prefix_to_save = file_prefix; } var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); if (new_feed_additions is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 759cbd66..c4a03985 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -6,9 +6,254 @@ using Tensorflow.Train; using static Tensorflow.ApiDef.Types; using static Tensorflow.CostGraphDef.Types; using static Tensorflow.OptimizerOptions.Types; +using static Tensorflow.Binding; +using System.Text.RegularExpressions; +using System.Linq; +using Tensorflow.Operations; +using Tensorflow.Training; +using 
Tensorflow.Graphs; namespace Tensorflow.Checkpoint { + /// + /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. + /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. + /// + public interface IFunctionHolder + { + int ArgCount { get; } + object DynamicInvoke(params object[] args); + } + internal record class FunctionHolder(Func Func): IFunctionHolder + { + public int ArgCount => 0; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 1; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 2; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 3; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + public class Maybe + { + private TA? _valueA = default(TA); + private TB? 
_valueB = default(TB); + private Type _type; + private bool _assigned = false; + public Maybe(TA value) + { + _valueA = value; + _type= typeof(TA); + _assigned = true; + } + public Maybe(TB value) + { + _valueB = value; + _type = typeof(TB); + _assigned = true; + } + + public Type DataType => _type; + + public TA GetValueA() + { + if(!_assigned || DataType != typeof(TA)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueA; + } + public TB GetValueB() + { + if (!_assigned || DataType != typeof(TB)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueB; + } + public object GetValue() + { + if (!_assigned) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + if(DataType == typeof(TA) && _valueA is not null) + { + return _valueA; + } + else if(DataType == typeof(TB) && _valueB is not null) + { + return _valueB; + } + else if(DataType == typeof(TA)) + { + return _valueA; + } + else + { + return _valueB; + } + } + + public static implicit operator Maybe(TA a) + { + return new Maybe(a); + } + public static implicit operator Maybe(TB b) + { + return new Maybe(b); + } + } + internal class SingleDeviceSaver + { + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict; + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public Operation? save(Tensor file_prefix, CheckpointOptions? 
options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensors = new(); + List slice_specs = new(); + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + var tensor_value = spec.tensor; + if (tensor_value is not null) + { + tensor_names.Add(spec.name); + tensors.Add(tensor_value); + slice_specs.Add(spec.slice_spec); + } + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_names.Add(checkpoint_key); + tensors.Add(tensor); + slice_specs.Add(slice_spec); + } + } + } + // TODO: specify the device. + return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); + } + + public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); + + public IDictionary> restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensor_dtypes = new(); + List slice_specs = new(); + + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. 
+ if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + tensor_dtypes.Add(spec.dtype); + slice_specs.Add(spec.slice_spec); + tensor_names.Add(spec.name); + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_dtypes.Add(tensor.dtype); + slice_specs.Add(slice_spec); + tensor_names.Add(checkpoint_key); + } + } + } + + string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; + + // tf python has code `with ops.device(restore_device):` here. + tf.device(restore_device); // may be risky. + var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + + Dictionary> restored_tensor_dict = new(); + int idx = 0; + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice_spec in tensor_slices.Keys) + { + var restored_tensor = restored_tensors[idx++]; + if (!restored_tensor_dict.ContainsKey(checkpoint_key)) + { + restored_tensor_dict[checkpoint_key] = new Dictionary(); + } + restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; + } + } + return restored_tensor_dict; + } + + public IDictionary> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix)); + } /// /// Saves checkpoints directly from multiple devices. /// Note that this is a low-level utility which stores Tensors in the keys @@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint /// public class MultiDeviceSaver { - public MultiDeviceSaver(IDictionary> serialized_tensors, + private Dictionary _single_device_savers; + private IDictionary _registered_savers; + private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; + /// + /// + /// + /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. 
+ /// + /// + public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { + _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); + _restore_fn_to_keys = new Dictionary>(); + Dictionary>> tensors_by_device= new(); + + foreach(var pair in serialized_tensors) + { + var obj = pair.Key; + var tensor_dict = pair.Value; + IFunctionHolder restore_fn; + if(obj is null) + { + restore_fn = new FunctionHolder(() => null); + } + else + { + restore_fn = null; + // TODO: implement obj._restore_from_tensors + } + + foreach(var item in tensor_dict) + { + var checkpoint_key = item.Key; + IDictionary spec_to_tensor; + if(item.Value.DataType != typeof(IDictionary)) + { + spec_to_tensor = new Dictionary(); + spec_to_tensor[""] = item.Value.GetValueA(); + } + else + { + spec_to_tensor = item.Value.GetValueB(); + } + + foreach(var spec in spec_to_tensor) + { + var slice_spec = spec.Key; + var tensor = spec.Value; + if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) + { + throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + + $"slice spec. This is invalid because one will overwrite the " + + $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); + } + _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; + _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); + + // skip the process of device name because lack of API. 
+ var host_device = tensor.Device; + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>()); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + internal_dict[checkpoint_key] = new Dictionary(); + } + internal_dict[checkpoint_key][slice_spec] = tensor; + } + } + } + + _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); + _registered_savers = new Dictionary(); + if(registered_savers is not null && registered_savers.Count > 0) + { + // TODO: complete the implementation. + throw new NotImplementedException(); + } } - public Operation? save(string file_prefix, CheckpointOptions? options= null) + public Operation save(string file_prefix, CheckpointOptions? options= null) { - throw new NotImplementedException(); + if(options is null) + { + options = new CheckpointOptions(); + } + + tf.device("CPU"); // may be risky. + // TODO: optimize the implementation with new APIs adding to `string_ops`. + string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; + var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + + Operation save_fn() + { + List saved_prefixes= new(); + foreach(var saver in _registered_savers) + { + // TODO: implementi it later. + throw new NotImplementedException(); + } + + int num_shards = _single_device_savers.Count; + List sharded_saves = new(); + var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); + string? last_device = null; + int shard = 0; + foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) + { + var device = pair.Key; + var saver = pair.Value; + last_device = device; + // skip the extra process of device name because of lack of API. 
+ tf.device(device); + var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + saved_prefixes.Add(shard_prefix); + sharded_saves.Add(saver.save(shard_prefix, options)); + } + using (var controller = ops.control_dependencies(sharded_saves.ToArray())) + { + string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; + tf.device(merge_device); + return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + } + } + + if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return save_fn(); + } } - public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + + public IDictionary restore(string file_prefix, CheckpointOptions? 
options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + IDictionary restore_func() + { + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary restore_ops = new(); + + foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) + { + var device = single_saver.Key; + var saver = single_saver.Value; + tf.device(device); + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach(var pair in restored_tensor_dict) + { + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach(var item in slice_and_tensor) + { + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) + { + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(dict); + } + else + { + internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; + } + } + else + { + internal_dict[checkpoint_key] = new Maybe>(tensor); + } + restore_fn_input_count[restore_fn]--; + + if (restore_fn_input_count[restore_fn] == 0) + { + Dictionary>> restored_tensors = new(); + foreach(var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if(ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } + } + } + } + } + + foreach(var item in _registered_savers) + { + throw new NotImplementedException(); + } + return restore_ops; + } + + // TODO: complete the implementation. Currently skip it because of lack of API. 
+ bool has_custom_device_saver = false; + + if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return restore_func(); + } + } + + /// + /// Serializes to a SaverDef referencing the current graph. + /// + public SaverDef to_proto() + { + var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); + var save_tensor = _traced_save(filename_tensor); + var restore_op = _traced_restore(filename_tensor).op; + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + Version = SaverDef.Types.CheckpointFormatVersion.V2 + }; + } + + [AutoGraph] + private Tensor _traced_save(Tensor file_prefix) + { + var save_op = save(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[]{ save_op })) + { + return array_ops.identity(file_prefix); + } + } + + [AutoGraph] + private Tensor _traced_restore(Tensor file_prefix) + { + var restore_op = restore(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[] { restore_op })) + { + return array_ops.identity(file_prefix); + } + } + + private static Tensor registered_saver_filename(string filename, string saver_name) + { + return tf.constant($"{filename}-{saver_name}"); + } + private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - throw new NotImplementedException(); + return filename_tensor; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs new file mode 100644 index 00000000..ae8a1ab1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs @@ 
-0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public interface ISerializedAttributes + { + IDictionary Functions { get; } + + IDictionary CheckpointableObjects { get; } + + /// + /// Returns functions to attach to the root object during serialization. + /// + IDictionary FunctionsToSerialize { get; } + + /// + /// Returns objects to attach to the root object during serialization. + /// + IDictionary ObjectsToSerialize{get; } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + IDictionary set_and_validate_functions(IDictionary function_dict); + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + IDictionary set_and_validate_objects(IDictionary object_dict); + } +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 5dd9784f..4d5a664e 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -24,7 +25,7 @@ namespace Tensorflow.Train } } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) { if(save_type != SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 82da2ee9..97162651 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -4,57 +4,130 @@ using Tensorflow.Train; using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow; public class AugmentedGraphView: ObjectGraphView { - // private object _children_cache; - // private object _serialization_cache; + private Dictionary> _children_cache; + private Dictionary> _serialization_cache; private List _untraces_functions; + private Dictionary _wrapped_functions; public AugmentedGraphView(Trackable root): base(root) { - _untraces_functions = new(); + _children_cache= new Dictionary>(); + _serialization_cache = new Dictionary>(); + _untraces_functions = new List(); + _wrapped_functions = new Dictionary(); } - public void set_signature(object signature_map, object wrapped_functions) + public void set_signature(SignatureMap signature_map, IDictionary wrapped_functions) { - // TODO: cache list_children(Root); + var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; + if (!_children_cache.ContainsKey(Root)) + { + _children_cache[Root] = new Dictionary(); + } + _children_cache[Root][name] = signature_map; + _wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary>? 
serialization_cache = null) { - Dictionary children = new(); - foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) + if(serialization_cache is not null) + { + throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); + } + + if (!_children_cache.ContainsKey(obj)) + { + Dictionary children = new Dictionary(); + _children_cache[obj] = children; + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) + { + child = maybe_uncache_variable_captures((ConcreteFunction)child); + } + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + } + + List res = new(); + foreach(var pair in _children_cache[obj]) { - var name = pair.Name; - var child = pair.Refer; - children[name] = child; + res.Add(new TrackableReference(pair.Key, pair.Value)); } - if (obj is Function && children.Count == 0) + return res; + } + + private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) + { + if (_wrapped_functions.ContainsKey(concrete_function)) { - _untraces_functions.Add(((Function)obj).Name); + return _wrapped_functions[concrete_function]; } + // skip the process here because of lack of feature. + // In the future, we may add an attribute which could specify if the variable is supposed to be cached. + //foreach(var capture in concrete_function.CapturedInputs) + //{ - return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + //} + return concrete_function; } public override (List, Dictionary>) breadth_first_traversal() { - // TODO: implement it if needed. + Trackable get_merged_trackable(Trackable x) + { + // TODO: complete it with new definitions `Asset` and `TrackableConstant`. 
+ return x; + } + var trackable_objects = base.breadth_first_traversal(); + + foreach(var obj in _children_cache.Keys) + { + // skip the deletion of cache (maybe do it later). + foreach(var pair in _children_cache[obj]) + { + _children_cache[obj][pair.Key] = get_merged_trackable(pair.Value); + } + } + return base.breadth_first_traversal(); } public List<(string, Trackable)> list_dependencies(Trackable obj) { - // TODO: deal with cache. - return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + IDictionary children; + if (!_children_cache.ContainsKey(obj)) + { + children= new Dictionary(); + } + else + { + children= _children_cache[obj]; + } + List<(string, Trackable)> res = new(); + foreach(var pair in obj.deserialization_dependencies(children)) + { + res.Add((pair.Key, pair.Value)); + } + return res; } public Trackable get_child(Trackable obj, string name) { - throw new NotImplementedException(); + return _children_cache[obj][name]; } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6a241f0e..6700e277 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -141,16 +141,16 @@ public class SaveableView foreach (var node in _nodes) { var node_id = _node_ids[node]; - List deps = new(); + List deps = new List(); + dependency_map.Add(node_id, deps); // TODO: deal with captured tensor. - string node_path; foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) { if (!_node_ids.ContainsKey(dep)) { - node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); throw new ValueError( $"Found an untracked dependency. Object {node_path} depends on {dep}, " + $"but this dependency isn't listed as a child. 
Please track this child by " + diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index cc839952..f3f273b8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -24,7 +24,7 @@ public static partial class SavedModelUtils }.Select(x => (int)x); public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, - string export_dir, IDictionary? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) + string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) { if (options is null) { @@ -41,9 +41,9 @@ public static partial class SavedModelUtils if (!experimental_skip_checkpoint) { - Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); + SavedModelUtils.get_or_create_variables_dir(export_dir); CheckpointOptions ckpt_options = new(options.experimental_io_device); - object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); } BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); @@ -67,7 +67,7 @@ public static partial class SavedModelUtils } var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); - File.WriteAllText(path, saved_model.ToString()); + File.WriteAllBytes(path, saved_model.ToByteArray()); if (options.save_debug_info) { @@ -81,7 +81,7 @@ public static partial class SavedModelUtils private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, Dictionary>) _build_meta_graph(Trackable obj, - IDictionary? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) + ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? 
meta_graph_def = null) { if (ops.inside_function()) { @@ -95,9 +95,9 @@ public static partial class SavedModelUtils } AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is not null) + if (signatures is null) { - throw new NotImplementedException(); + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); } // TODO: process of aignatures and wrapped_functions @@ -125,7 +125,7 @@ public static partial class SavedModelUtils } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, - IDictionary signatures, IEnumerable namespace_whitelist, + ConcreteFunction signatures, IEnumerable namespace_whitelist, bool save_custom_gradients) { var resource_initializers = saveable_view.get_concrete_resource_initializers(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 21272941..0d34907f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -1,15 +1,84 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow; +public static class SignatureSerializationUtils +{ + internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; + internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; + internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + Debug.Assert(func is ConcreteFunction); + // TODO: assert the 
`func.structured_outputs` and arg_keywords. + signature_map._add_signature(name, (ConcreteFunction)func); + } + + return signature_map; + } + + public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) + { + var children = graph_view.list_children(graph_view.Root); + List possible_signatures = new(); + foreach (var item in children) + { + var name = item.Name; + var child = item.Refer; + if(child is not (Function or ConcreteFunction)) + { + continue; + } + if(name == DEFAULT_SIGNATURE_ATTR) + { + Debug.Assert(child is ConcreteFunction); + return (ConcreteFunction)child; + } + ConcreteFunction concrete = get_signature(child); + if(concrete is not null && valid_signature(concrete)) + { + possible_signatures.Add(concrete); + } + } + + if(possible_signatures.Count == 1) + { + var signature = get_signature(possible_signatures[0]); + if(signature is not null && valid_signature(signature)) + { + return signature; + } + } + return null; + } + + private static ConcreteFunction get_signature(Trackable function) + { + // TODO: implement it. + return null; + } + + private static bool valid_signature(ConcreteFunction concreate_function) + { + // TODO: implement it. + return false; + } +} + public class SignatureMap: Trackable { - private Dictionary _signatures; - private Dictionary _concrete_signatures; + private Dictionary _signatures; public SignatureMap() { @@ -18,7 +87,7 @@ public class SignatureMap: Trackable public void _add_signature(string name, ConcreteFunction concrete_function) { - _concrete_signatures[name] = concrete_function; + _signatures[name] = concrete_function; } public void _add_signature(string name, Function concrete_function) @@ -26,33 +95,13 @@ public class SignatureMap: Trackable _signatures[name] = concrete_function; } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) { if (save_type != SaveType.SAVEDMODEL) { return new Dictionary(); } - Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); - foreach (var pair in _concrete_signatures) - { - res[pair.Key] = pair.Value; - } - - return res; - } - - public static SignatureMap create_signature_map(IDictionary signatures) - { - var signature_map = new SignatureMap(); - foreach (var pair in signatures) - { - var name = pair.Key; - var func = pair.Value; - // TODO: assert the arg_keywords - signature_map._add_signature(name, func); - } - - return signature_map; + return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 622eed3a..7066b366 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -16,18 +16,38 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; using Tensorflow.Train; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow { - public static class saveable_object_util + /// + /// A SaveableObject that defines `Trackable` checkpointing steps. + /// + public class TrackableSaveable : MySaveableObject { - public class TrackableSaveable: MySaveableObject + private string _prefix; + private IEnumerable _local_names; + private Trackable _trackable; + private bool _call_with_mapped_captures; + // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. 
+ public TrackableSaveable(Trackable obj, IEnumerable specs, string name, IEnumerable local_names, + string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) { - + _prefix = prefix; + _trackable = obj; + _local_names = local_names; + _call_with_mapped_captures = call_with_mapped_captures; } + + // TODO: complete this class. + } + public static class saveable_object_util + { /// /// Returns the variables and names that will be used for a Saver. /// @@ -57,7 +77,7 @@ namespace Tensorflow } /// - /// Create `SaveableObject`s from an operation. + /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// /// /// @@ -79,6 +99,74 @@ namespace Tensorflow } } + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Trackable obj, string name) + { + // The `op` maybe `Variable` or `Trackable`. + if (obj is BaseResourceVariable) + { + var variable = obj as BaseResourceVariable; + if (variable.InGraphMode) + { + yield return new ResourceVariableSaveable(variable.GraphElement, "", name); + } + else + { + Debug.Assert(variable is ResourceVariable); + yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + } + } + else + { + foreach(var pair in saveable_objects_from_trackable(obj)) + { + var attr = pair.Key; + var factory = pair.Value; + string full_name; + if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) + { + full_name = name; + } + else + { + full_name = name + "_" + attr; + } + if(factory.DataType == typeof(ResourceVariable)) + { + var variable = factory.GetValueA(); + foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + { + yield return op; + } + } + else + { + var variable = factory.GetValueB(); + foreach (var op in saveable_objects_for_op(variable, variable.name)) + { + yield return op; + } + } + } + } + } + + /// + /// Create 
`SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(MySaveableObject obj, string name) + { + yield return obj; + } + public static Dictionary op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) { op_list = op_list.OrderBy(x => x.Name).ToArray(); @@ -127,16 +215,55 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { - // TODO: complete the implementation. - return obj.gather_saveables_for_checkpoint(); + // skip the process of type `PythonState` + + if (trackable_has_serialize_to_tensor(obj)) + { + var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; + // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. + var tensor_dict = obj.serialize_to_tensors(); + + List specs = new(); + List local_names = new(); + string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; + foreach(var pair in tensor_dict) + { + var tensor_name = pair.Key; + var maybe_tensor = pair.Value; + local_names.Add(tensor_name); + string spec_name = name + TrackableUtils.escape_local_name(tensor_name); + + IDictionary internal_dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + internal_dict= new Dictionary(); + internal_dict[""] = maybe_tensor.GetValueA(); + } + else + { + internal_dict = maybe_tensor.GetValueB(); + } + + foreach(var item in internal_dict) + { + specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); + } + } + Dictionary> res = new(); + res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return res; + } + else + { + return obj.gather_saveables_for_checkpoint(); + } } public static bool trackable_has_serialize_to_tensor(Trackable obj) { - // TODO: implement it. 
- return false; + return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); } internal static string convert_to_string(string x) @@ -158,27 +285,28 @@ namespace Tensorflow public Trackable Obj => _obj; public IList mySaveables=> _saveables; - public override IDictionary serialize_to_tensors() + public override IDictionary>> serialize_to_tensors() { - return saveable_objects_to_tensor_dict(_saveables); + return saveable_object_to_tensor_dict(_saveables); } /// /// Converts a list of SaveableObjects to a tensor dictionary. /// /// - public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) { - Dictionary tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { foreach(var spec in saveable.specs) { + // skip the check that if `spec` is callable. var name = saveable_object_util.convert_to_string(spec.name); var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - throw new NotImplementedException(); + tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; } else { diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 2646fb8d..a677044a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -16,7 +16,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; using Tensorflow.Training; using static Tensorflow.Binding; @@ -39,8 +42,8 @@ namespace Tensorflow.Train protected IList _unconditional_checkpoint_dependencies; - protected IDictionary _self_saveable_object_factories = - new Dictionary(); + protected IDictionary> _self_saveable_object_factories = + new 
Dictionary>(); private bool _manual_tracking = true; private static Trackable _none = new Function(); @@ -94,9 +97,13 @@ namespace Tensorflow.Train // assign again. It will add this variable to our dependencies, and if there // is a non-trivial restoration queued, it will handle that. This also // handles slot variables. - if (!args.Overwrite || new_variable is RefVariable) - return _track_checkpointable(new_variable, name: args.Name, - overwrite: args.Overwrite); + if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) + { + var temp = new_variable as Trackable; + var res = _track_trackable(temp, args.Name, args.Overwrite); + Debug.Assert(res is IVariableV1); + return res as IVariableV1; + } else return new_variable; } @@ -122,13 +129,16 @@ namespace Tensorflow.Train /// public void _maybe_initialize_trackable() { + if(_unconditional_checkpoint_dependencies is not null) + { + return; + } _self_update_uid = -1; _unconditional_checkpoint_dependencies = new List(); _unconditional_dependency_names = new Dictionary(); } - // TODO: cache - public virtual IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? cache) { _maybe_initialize_trackable(); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); @@ -139,8 +149,8 @@ namespace Tensorflow.Train _maybe_initialize_trackable(); if (!_manual_tracking) return trackable; var new_reference = new TrackableReference(name, trackable); - var current_object = _lookupup_dependency(name); - + var current_object = _lookup_dependency(name); + if(current_object is null) { _unconditional_checkpoint_dependencies.Add(new_reference); @@ -170,7 +180,7 @@ namespace Tensorflow.Train // TODO: complete the implementation. } - public virtual Trackable? _lookupup_dependency(string name) + public virtual Trackable? 
_lookup_dependency(string name) { if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; else return null; @@ -199,8 +209,8 @@ namespace Tensorflow.Train return (new Dictionary(), new Dictionary()); } - public virtual List export_to_saved_model_graph(IDictionary? object_map = null, - IDictionary? tensor_map = null, SaveOptions? options = null) + public virtual List export_to_saved_model_graph(IDictionary object_map, + IDictionary tensor_map, SaveOptions? options = null) { var (self_object_map, self_tensor_map) = map_resources(options); foreach (var pair in self_object_map) @@ -215,9 +225,17 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { - return _self_saveable_object_factories; + if (saveable_object_util.trackable_has_serialize_to_tensor(this)) + { + // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). 
+ throw new NotImplementedException(); + } + else + { + return _self_saveable_object_factories; + } } /// @@ -229,7 +247,7 @@ namespace Tensorflow.Train /// /// /// - public virtual IDictionary serialize_to_tensors() + public virtual IDictionary>> serialize_to_tensors() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 99020702..390d95c7 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using Tensorflow.Exceptions; using Tensorflow.Train; @@ -22,7 +23,7 @@ public static class TrackableUtils private static string _ESCAPE_CHAR = "."; private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; - private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; public static string object_path_to_string(IEnumerable node_path_arr) { return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); @@ -145,4 +146,27 @@ public static class TrackableUtils return $"root.{string.Join(".", paths.Select(x => x.Name))}"; } } + + /// + /// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key. + /// + /// + /// + /// + public static string extract_local_name(string key, string? 
prefix = null) + { + if(prefix is null) + { + prefix = ""; + } + var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix; + try + { + return key.Substring(key.IndexOf(search_key) + search_key.Length); + } + catch(ArgumentOutOfRangeException) + { + return key; + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index d4e9c401..6e3336c9 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -9,6 +9,7 @@ using System.Runtime.InteropServices; using System.Text; using Tensorflow.Functions; using Tensorflow.Keras; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using Tensorflow.Train; using static Tensorflow.ApiDef.Types; @@ -243,7 +244,7 @@ namespace Tensorflow.Training _last_wrapped_list_snapshot = new List(_storage); } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) { check_external_modification(); if (_non_append_mutation_value) diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index f217a052..756024db 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -4,6 +4,8 @@ using Tensorflow.Eager; using Tensorflow.Variables; using Tensorflow.Train; using static Tensorflow.Binding; +using System.Collections.Generic; +using Tensorflow.ModelSaving; namespace Tensorflow { @@ -20,6 +22,7 @@ namespace Tensorflow public string UniqueId => _unique_id; protected bool _in_graph_mode; + internal bool InGraphMode => _in_graph_mode; protected bool _trainable; public bool Trainable => _trainable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b31960c7..6093f810 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -17,7 +17,9 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using Tensorflow.Checkpoint; using Tensorflow.NumPy; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow @@ -235,5 +237,12 @@ namespace Tensorflow { return _graph_element.eval(session); } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 61a8956a..7c8812ad 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; using Tensorflow.Train; 
using static Tensorflow.Binding; @@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine return output_tensors; } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index 1675fba1..ffb6f71b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; @@ -9,16 +10,16 @@ public abstract partial class Layer { public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); - public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { IDictionary children; if (save_type == SaveType.SAVEDMODEL) { - // TODO: deal with cache. 
+ Debug.Assert(cache is not null); children = TrackableSavedModelSaver.trackable_children(cache); } else diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index b9b01dae..a2f92ba8 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; - public Tensor[] input => inboundNodes[0].input_tensors; + public Tensor[] input + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].input_tensors; + } + return null; + } + } public Dictionary> NodesByDepth { get; set; } - public Shape OutputShape => inboundNodes[0].Outputs.shape; + public Shape OutputShape + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].Outputs.shape; + } + return null; + } + } protected List _self_tracked_trackables; public Layer(LayerArgs args) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59f74cd2..59b205e4 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine bool include_optimizer = true, string save_format = "tf", SaveOptions? options = null, - IDictionary? signatures = null, + ConcreteFunction? 
signatures = null, bool save_traces = true) { if (save_format != "pb") diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 41f7788e..dfe5b05f 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine } } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { if(save_type == SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 76453ca0..6a6e418c 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures, + public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, SaveOptions? 
options, bool save_traces = true) { if (!overwrite && File.Exists(filepath)) @@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils } var metadata = generate_keras_metadata(saved_nodes, node_paths); - using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, - FileAccess.Write)) - { - var writer = new StreamWriter(f); - writer.Write(metadata.ToString()); - } + File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); if (!include_optimizer) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index 7168e25b..fc7eab3a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils /// /// /// - public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) { // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. @@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils /// /// /// - public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) { // TODO: deal with type `RevivedLayer` and `Sequential`. 
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index a399eaf1..0235f87b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -18,12 +18,12 @@ public abstract class SavedModelSaver public abstract string TrackingMetadata { get; } public abstract IDictionary objects_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); public abstract IDictionary functions_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); - public IDictionary trackable_children(IDictionary? serialization_cache) + public IDictionary trackable_children(IDictionary> serialization_cache) { if (!KerasSavedModelUtils.ShouldHaveTraces) { @@ -31,7 +31,6 @@ public abstract class SavedModelSaver } var children = objects_to_serialize(serialization_cache); - return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 7a0ddd21..b092b595 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) { return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) { 
return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } @@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver /// Generates or retrieves serialized attributes from cache. /// /// - protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) + protected ISerializedAttributes get_serialized_attributes(IDictionary> serialization_cache) { // TODO: deal with cache. + IDictionary keras_cache; + if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) + { + keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; + } + else + { + serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary(); + } + if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; - var serialized_attr = SerializedAttributes.Create(_obj); + var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. if (KerasSavedModelUtils.should_skip_serialization(_obj)) @@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver /// Returns dictionary of serialized attributes. /// /// - private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); @@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver metadata["trainable"] = _obj.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _obj.BatchInputShape is null ? 
null : JToken.FromObject(_obj.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; @@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver } } - public static LayerConfig get_serialized(Layer obj) + public static IDictionary get_serialized(Layer obj) { - return generic_utils.serialize_keras_object(obj); + // TODO: complete the implmentation (need to revise `get_config`). + return new Dictionary(); + //return generic_utils.serialize_keras_object(obj); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 804ea1a9..ac194c00 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel /// /// Class that tracks and validates all serialization attributes. 
/// - public abstract class SerializedAttributes + public abstract class SerializedAttributes: ISerializedAttributes { protected IDictionary _object_dict; protected IDictionary _function_dict; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 9d1b3088..0f34ff10 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -50,11 +50,11 @@ public class SaveTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 50000, + ValidationSize = 0, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - model.save("", save_format:"pb"); + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } } \ No newline at end of file From f2e41a17916b25ff6fd3baf20ed6fc0d651fb4c2 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 2 Feb 2023 17:34:50 +0800 Subject: [PATCH 08/10] Support autograph.to_graph under graph mode. --- src/TensorFlowNET.Core/Graphs/AutoGraph.cs | 46 +++++++++++++++------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 2af1a372..ceeca8ab 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs { public class AutoGraph { - public Func to_graph(Func func) + public Func to_graph(Func func, TF_DataType dtype = TF_DataType.TF_INT32) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input = tf.placeholder(tf.int32); + var input = tf.placeholder(dtype); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -26,25 +27,33 @@ namespace Tensorflow.Graphs return 
(Tensor input) => { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - func_name, - new[] { input }, - null, - 1); - return result[0]; + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { input }, + null, + 1); + return result[0]; + } + using (var s = tf.Session(input.graph)) + { + var output = func(input); + return output; + } }; } - public Func to_graph(Func func) + public Func to_graph(Func func, params TF_DataType[] dtypes) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input1 = tf.placeholder(tf.int32); - var input2 = tf.placeholder(tf.int32); + var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); + var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); var output = func(input1, input2); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -56,13 +65,22 @@ namespace Tensorflow.Graphs return (Tensor a, Tensor b) => { - var result = tf.Runner.TFE_Execute(tf.Context, + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { a, b }, null, 1); - return result[0]; + return result[0]; + } + using (var s = tf.Session(a.graph)) + { + Debug.Assert(a.graph == b.graph); + var output = func(a, b); + return output; + } }; } } From a479e53f3aad18a6272eeddb8f3243b10f3beffb Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 2 Feb 2023 19:17:15 +0800 Subject: [PATCH 09/10] Add more implementations to the pb model save. 
--- .../Checkpoint/CheckPointUtils.cs | 10 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 7 +- .../Checkpoint/SaveUtilV1.cs | 18 +-- .../Checkpoint/checkpoint.cs | 23 ++-- .../Checkpoint/functional_saver.cs | 60 ++++++---- src/TensorFlowNET.Core/Eager/execute.cs | 31 +++++ .../Framework/meta_graph.cs | 23 ++++ .../ModelSaving/SaveOptions.cs | 38 +++++- src/TensorFlowNET.Core/Operations/gen_ops.cs | 59 +++++++++- src/TensorFlowNET.Core/Operations/io_ops.cs | 32 +++++ .../Operations/resource_variable_ops.cs | 54 +++++++++ .../Saving/ResourceVariableSaveable.cs | 28 +++++ .../Training/Saving/SaveableObject.cs | 24 +++- .../Saving/SavedModel/SaveableView.cs | 14 +-- .../Training/Saving/SavedModel/save.cs | 83 +++++++------ .../Saving/SavedModel/save_context.cs | 53 +++++++++ .../Training/Saving/SavedModel/utils.cs | 5 + .../Saving/saveable_object_util.py.cs | 109 +++++++++++++----- src/TensorFlowNET.Core/Training/Trackable.cs | 13 ++- .../Variables/BaseResourceVariable.cs | 64 ++++++++++ .../Variables/ResourceVariable.cs | 8 +- .../Variables/UninitializedVariable.cs | 70 +++++++++++ .../Engine/Functional.GetConfig.cs | 2 +- .../Engine/Layer.Serialize.cs | 2 +- .../Layers/Core/InputLayer.cs | 3 + .../Saving/SavedModel/Save.cs | 4 +- .../Saving/SavedModel/SaveImpl.cs | 24 ++-- .../Saving/SavedModel/base_serialization.cs | 2 +- .../Saving/SavedModel/layer_serialization.cs | 63 ++++++++-- .../Utils/generic_utils.cs | 9 +- 30 files changed, 775 insertions(+), 160 deletions(-) create mode 100644 src/TensorFlowNET.Core/Eager/execute.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs create mode 100644 src/TensorFlowNET.Core/Variables/UninitializedVariable.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 70d77155..cd37703b 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -1,5 
+1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using Tensorflow.Train; @@ -85,17 +86,18 @@ public static class CheckPointUtils } } - public static string get_full_name(Trackable var) + public static string get_full_name(Trackable variable) { // TODO: This state is not correct, the whole framework need to be updated in the future. - if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) + if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable))) { return ""; } // skip the check of attribute `_save_slice_info` . - + // TODO: Need to be revised!!! - return ((ResourceVariable)(object)var).Name; + Debug.Assert(variable is BaseResourceVariable); + return ((BaseResourceVariable)variable).Name; } public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index e646f1f0..84e0ca4e 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? 
cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -39,7 +39,7 @@ namespace Tensorflow.Checkpoint var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); - Dictionary feed_additions; + Dictionary feed_additions; if(cache is null) { feed_additions = null; @@ -125,7 +125,7 @@ namespace Tensorflow.Checkpoint { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; - var trackable = td.object_to_save; + Trackable trackable = null; IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { @@ -134,6 +134,7 @@ namespace Tensorflow.Checkpoint else { tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + trackable = td.object_to_save; } if(trackable is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index d8e251ec..4f1d04d2 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -44,19 +44,19 @@ public static class SaveUtilV1 return (checkpoint_factory_map, null); } - public static (List, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, + public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? 
saveables_cache = null) { if (to_graph is not null) { - to_graph.as_default(); + var g = to_graph.as_default(); var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var serialized = graph_proto.ToByteString().ToString(); - var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + tf.device("/cpu:0"); + var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + g.Exit(); return (named_saveable_objects, registered_savers); } else @@ -65,7 +65,7 @@ public static class SaveUtilV1 { var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` + tf.device("/cpu:0"); var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); @@ -73,7 +73,7 @@ public static class SaveUtilV1 } } - public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + public static (List, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); @@ -129,7 +129,7 @@ public static class SaveUtilV1 return object_graph_proto; } - private static (List, object?, object?) add_attributes_to_object_graph(IList trackable_objects, + private static (List, object?, IDictionary>?) 
add_attributes_to_object_graph(IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -216,7 +216,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - Maybe factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index c9bee0db..0c2862da 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public class TrackableSaver } - private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -42,26 +42,27 @@ public class TrackableSaver if(object_graph_tensor is null) { - // tensorflow python: `with ops.device("/cpu:0"):` - object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + tf.device("/cpu:0"); + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); } else { - feed_additions[object_graph_tensor] = graph_proto.ToString(); + feed_additions[object_graph_tensor] = graph_proto.ToByteArray(); } Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - if (serialized_tensors.ContainsKey(Trackable.None)) + if (!serialized_tensors.ContainsKey(Trackable.None)) { - serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + serialized_tensors[Trackable.None] = new Dictionary>>(); } + 
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; return (serialized_tensors, feed_additions, registered_savers, graph_proto); } - private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -86,11 +87,11 @@ public class TrackableSaver return run_save(); } - private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -124,7 +125,7 @@ public class TrackableSaver options = new CheckpointOptions(); } - Dictionary feed_dict = new(); + Dictionary feed_dict = new(); bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index c4a03985..90bbccf0 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs 
@@ -12,6 +12,8 @@ using System.Linq; using Tensorflow.Operations; using Tensorflow.Training; using Tensorflow.Graphs; +using System.Xml.Linq; +using System.Diagnostics; namespace Tensorflow.Checkpoint { @@ -31,6 +33,10 @@ namespace Tensorflow.Checkpoint { return Func.DynamicInvoke(args); } + public TR Invoke() + { + return Func.Invoke(); + } } internal record class FunctionHolder(Func Func) : IFunctionHolder { @@ -164,7 +170,6 @@ namespace Tensorflow.Checkpoint { var slice_spec = slice.Key; var maybe_tensor = slice.Value; - // TODO: deal with other types. Currently only `SaveSpec` is allowed. if(maybe_tensor.DataType == typeof(SaveSpec)) { var spec = maybe_tensor.GetValueB(); @@ -284,14 +289,16 @@ namespace Tensorflow.Checkpoint var obj = pair.Key; var tensor_dict = pair.Value; IFunctionHolder restore_fn; - if(obj is null) + if(obj == Trackable.None) { restore_fn = new FunctionHolder(() => null); } else { - restore_fn = null; - // TODO: implement obj._restore_from_tensors + restore_fn = new FunctionHolder>>, IDictionary>(x => + { + return obj._restore_from_tensors(x); + }); } foreach(var item in tensor_dict) @@ -343,7 +350,7 @@ namespace Tensorflow.Checkpoint } } - public Operation save(string file_prefix, CheckpointOptions? options= null) + public Operation save(Tensor file_prefix, CheckpointOptions? options= null) { if(options is null) { @@ -351,9 +358,9 @@ namespace Tensorflow.Checkpoint } tf.device("CPU"); // may be risky. - // TODO: optimize the implementation with new APIs adding to `string_ops`. - string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? 
".part" : "_temp/part"; - var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + constant_op.constant(".part"), constant_op.constant("_temp/part")); + var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); Operation save_fn() @@ -385,7 +392,7 @@ namespace Tensorflow.Checkpoint { string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; tf.device(merge_device); - return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); } } @@ -400,9 +407,9 @@ namespace Tensorflow.Checkpoint } } - public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); - public IDictionary restore(string file_prefix, CheckpointOptions? options = null) + public IDictionary restore(Tensor file_prefix, CheckpointOptions? 
options = null) { if(options is null) { @@ -496,8 +503,10 @@ namespace Tensorflow.Checkpoint public SaverDef to_proto() { var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); - var save_tensor = _traced_save(filename_tensor); - var restore_op = _traced_restore(filename_tensor).op; + var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); + var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); + var save_tensor = traced_save_func(filename_tensor); + var restore_op = traced_restore_func(filename_tensor).op; return new SaverDef() { FilenameTensorName = filename_tensor.name, @@ -507,10 +516,9 @@ namespace Tensorflow.Checkpoint }; } - [AutoGraph] private Tensor _traced_save(Tensor file_prefix) { - var save_op = save(file_prefix.StringData()[0]); + var save_op = save(file_prefix); tf.device("cpu:0"); using (ops.control_dependencies(new object[]{ save_op })) { @@ -518,24 +526,34 @@ namespace Tensorflow.Checkpoint } } - [AutoGraph] private Tensor _traced_restore(Tensor file_prefix) { - var restore_op = restore(file_prefix.StringData()[0]); + var restore_op = restore(file_prefix); tf.device("cpu:0"); - using (ops.control_dependencies(new object[] { restore_op })) + using (ops.control_dependencies(restore_op.Values.ToArray())) { return array_ops.identity(file_prefix); } } - private static Tensor registered_saver_filename(string filename, string saver_name) + public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? 
registered_savers = null, bool call_with_mapped_captures = false) + { + Dictionary>>> serialized_tensors = new(); + foreach (var saveable in saveables) + { + var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); + serialized_tensors[trackable] = trackable.serialize_to_tensors(); + } + return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); + } + + private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) { - return tf.constant($"{filename}-{saver_name}"); + return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); } private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - return filename_tensor; + return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); } } } diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs new file mode 100644 index 00000000..cb3ea4d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + internal class execute + { + public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) + { + var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); + var types = v.Select(t => t.dtype.as_datatype_enum()); + return (types.ToArray(), v.ToArray()); + } + public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + string device_name = ctx.DeviceName; + + ctx.ensure_initialized(); + var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); + 
+ return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index cce13b55..c3616faf 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -406,5 +406,28 @@ namespace Tensorflow meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; } + + /// + /// Extract the Op name from a Tensor name. + /// + /// + /// + public static string op_name(string tensor_name) + { + if (string.IsNullOrEmpty(tensor_name)) + { + throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); + } + + if (tensor_name.StartsWith("^")) + { + tensor_name = tensor_name.Substring(1); + } + if (tensor_name.Contains(":")) + { + return tensor_name.Split(':')[0]; + } + return tensor_name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index fce42850..45ebd884 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -14,11 +14,47 @@ namespace Tensorflow.ModelSaving public IDictionary? function_aliases { get; set; } = null; public string? experimental_io_device { get; set; } = null; // TODO: experimental - public Object? 
experimental_variable_polict { get; set; } = null; + public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; } } + + public class VariablePolicy + { + public string Policy { get; } + private VariablePolicy(string policy) + { + Policy = policy; + } + public static VariablePolicy None = new(null); + public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); + public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); + + public bool save_variable_devices() + { + return this != VariablePolicy.None; + } + + /// + /// Tries to convert `obj` to a VariablePolicy instance. + /// + /// + /// + public static VariablePolicy from_obj(object obj) + { + if (obj is null) return VariablePolicy.None; + if (obj is VariablePolicy) return (VariablePolicy)obj; + var key = obj.ToString().ToLower(); + return key switch + { + null => VariablePolicy.None, + "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") + }; + } + } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 11cb6de8..956be96b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -1,6 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Xml.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow.Operations @@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations /// path in the input checkpoint_prefixes. This is useful when those paths are non /// user-facing temporary locations. 
/// - public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints") - { + public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); + result = null; + return null; + //try + //{ + // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); + // result = null; + // return null; + //} + //catch (System.Exception) + //{ + // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, + // allow_missing_files: allow_missing_files, name: name, ctx: ctx); + //} + } var dict = new Dictionary(); dict["checkpoint_prefixes"] = checkpoint_prefixes; dict["destination_prefix"] = destination_prefix; - if (delete_old_dirs.HasValue) - dict["delete_old_dirs"] = delete_old_dirs.Value; + dict["delete_old_dirs"] = delete_old_dirs; var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); return op; } + //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) + //{ + // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); + // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); + // var 
inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; + // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; + // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); + // result = null; + // return null; + //} + /// /// Transforms a spectrogram into a form that's useful for speech recognition. /// @@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations /// public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); + return result[0]; + } var dict = new Dictionary(); dict["input"] = input; dict["pattern"] = pattern; @@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations /// public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); + return result[0]; + } var dict = new Dictionary(); dict["basename"] = basename; dict["shard"] = shard; @@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations /// public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); + return result[0]; + } var dict = new Dictionary(); dict["inputs"] = inputs; if (separator != null) diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 4f276e36..35c5877f 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ 
b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,7 +14,9 @@ limitations under the License. ******************************************************************************/ +using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -23,11 +25,41 @@ namespace Tensorflow { public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute( + new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); + result = null; + return null; + } + catch (System.Exception) + { + return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); + } + } var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); return _op; } + public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) + { + DataType[] attr_dtypes; + (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); + var attrs = new object[] { "dtypes", attr_dtypes }; + + var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); + result = null; + return null; + } + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { var _op = tf.OpDefLib._apply_op_helper("RestoreV2", 
name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index d5a32c10..1b1fa003 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,7 +17,9 @@ using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.ModelSaving; using Tensorflow.Train; +using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -177,5 +179,57 @@ namespace Tensorflow return HandleData.Parser.ParseFrom(handle.BufferToArray()); } } + + /// + /// Copies an existing variable to a new graph, with no initializer. + /// + /// + public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) + { + var new_variable = new UninitializedVariable( + trainable: variable.Trainable, + shape: variable.shape, + dtype: variable.dtype, + name: variable.SharedName, + aggregation: variable.Aggregation, + extra_handle_data: null); + new_variable._maybe_initialize_trackable(); + return new_variable; + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// + /// + /// + /// + /// + public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) + { + // lack of API: `proto.Variable.SetInParent()`. 
+ if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) + { + throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + + $"unexpected suffix in the name (expected ':0') which won't be restored."); + } + if(proto.Variable is null) + { + proto.Variable = new SavedVariable(); + } + proto.Variable.Name = meta_graph.op_name(resource_variable.Name); + proto.Variable.Trainable = resource_variable.Trainable; + proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); + // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. + proto.Variable.Aggregation = resource_variable.Aggregation; + proto.Variable.Shape = resource_variable.shape.as_proto(); + + if (options.experimental_variable_policy.save_variable_devices()) + { + if (!string.IsNullOrEmpty(resource_variable.Device)) + { + proto.Variable.Device = resource_variable.Device; + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 167c635a..2d23a325 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public class ResourceVariableSaveable : MySaveableObject @@ -35,6 +37,32 @@ namespace Tensorflow this.name = name; } + public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + + Tensor _read_variable_closure(BaseResourceVariable v) + { + tf.device(v.Device); + if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + tf.device("/device:CPU:0"); + return array_ops.identity(x); + } + + this.handle_op = var.Handle; + var tensor = _read_variable_closure(var); + + var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + _op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 6239030b..43d36dba 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -14,11 +14,31 @@ limitations under the License. 
******************************************************************************/ +using Tensorflow.Checkpoint; + namespace Tensorflow { public class MySaveableObject { - public Tensor op; + protected Maybe _op; + public Tensor op + { + get + { + if(_op.DataType == typeof(Tensor)) + { + return _op.GetValueA(); + } + else + { + throw new TypeError("The _op is not a tensor."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; @@ -35,7 +55,7 @@ namespace Tensorflow public MySaveableObject(Tensor op, SaveSpec[] specs, string name) { - this.op = op; + this._op = op; this.specs = specs; this.name = name; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6700e277..6132e025 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -10,6 +10,7 @@ using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -75,7 +76,7 @@ public class SaveableView private void initialize_save_and_restore_functions() { // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. - SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. 
_obj_to_registered_saver = new(); _saveable_objects_map = new(); @@ -191,7 +192,7 @@ public class SaveableView /// /// /// - public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index) { SavedObjectGraph proto = new(); fill_object_graph_proto(proto); @@ -203,21 +204,20 @@ public class SaveableView { var obj = _nodes[i]; var obj_proto = proto.Nodes[i]; - write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), - options); + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); } return proto; } private static void write_object_proto(Trackable obj, SavedObject proto, - IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + IDictionary asset_file_def_index, Func> list_children_fn) { // skip the process of type Asset if (resource_variable_ops.is_resource_variable(obj)) { - // TODO: complete it. 
- throw new NotImplementedException(); + var options = SaveContext.get_save_options(); + (obj as BaseResourceVariable).write_object_proto(proto, options); } else if (obj is Function) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index f3f273b8..d82d49d8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -10,6 +10,7 @@ using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Exceptions; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -43,7 +44,7 @@ public static partial class SavedModelUtils { SavedModelUtils.get_or_create_variables_dir(export_dir); CheckpointOptions ckpt_options = new(options.experimental_io_device); - object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options); } BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); @@ -68,6 +69,7 @@ public static partial class SavedModelUtils var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); File.WriteAllBytes(path, saved_model.ToByteArray()); + //File.WriteAllText(path, saved_model.ToString()); if (options.save_debug_info) { @@ -83,45 +85,48 @@ public static partial class SavedModelUtils Dictionary>) _build_meta_graph(Trackable obj, ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) { - if (ops.inside_function()) + using (SaveContext.save_context(options)) { - throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. 
" + - "Move the call to the outer eagerly-executed context."); - } + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } - if (meta_graph_def is null) - { - meta_graph_def = new MetaGraphDef(); - } + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } - AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is null) - { - signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); - } - - // TODO: process of aignatures and wrapped_functions + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is null) + { + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); + } - SaveableView saveable_view = new SaveableView(augmented_graph_view, options); - TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); - var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, - options.namespace_white_list, options.experimental_custom_gradients); - if (options.function_aliases is not null) - { - var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; - foreach (var pair in options.function_aliases) + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) { - var alias = pair.Key; - var func = pair.Value; - // TODO: complete it. 
- throw new NotImplementedException(); + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } } - } - var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); - meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); - return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, @@ -134,7 +139,7 @@ public static partial class SavedModelUtils Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - exported_graph.as_default(); + var g = exported_graph.as_default(); (object_map, tensor_map, asset_info) = saveable_view.map_resources(); // TODO: deal with signatures. if (save_custom_gradients) @@ -161,15 +166,23 @@ public static partial class SavedModelUtils // Lack `CopyFrom` API // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + g.Exit(); + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); } + // TODO: add the implementation of `call_with_mapped_functions`. var (named_saveable_objects, registered_savers) = SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); - - // TODO: complete the save of checkpoints with `MultiDeviceSaver`. 
+ var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); + + var eg = exported_graph.as_default(); + var saver_def = saver.to_proto(); + meta_graph_def.SaverDef = saver_def; + eg.Exit(); + saveable_view.dependency_sorted_node_ids(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs new file mode 100644 index 00000000..4cfe0b69 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.ModelSaving; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A context for building a graph of SavedModel. + /// + public static class SaveContext + { + // TODO: make it thead safe. + private static bool _in_save_context = false; + private static SaveOptions _save_options = null; + + public static bool in_save_context() => _in_save_context; + public static SaveOptions get_save_options() + { + if (!in_save_context()) + { + throw new ValueError("Not in a SaveContext."); + } + return _save_options; + } + public static SaveContextHandler save_context(SaveOptions options) + { + return new SaveContextHandler(options); + } + + public class SaveContextHandler: IDisposable + { + private bool _old_in_save_context; + private SaveOptions _old_save_options; + public SaveContextHandler(SaveOptions options) + { + if (SaveContext.in_save_context()) + { + throw new ValueError("Already in a SaveContext."); + } + _old_in_save_context = SaveContext._in_save_context; + SaveContext._in_save_context = true; + _old_save_options = SaveContext._save_options; + SaveContext._save_options = options; + } + public void Dispose() + { + SaveContext._in_save_context = _old_in_save_context; + SaveContext._save_options = _old_save_options; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs index 723419f6..2deff027 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -28,6 +28,11 @@ public static partial class SavedModelUtils return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); } + public static string get_variables_path(string export_dir) + { + return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); + } + /// /// Return assets sub-directory, or create one if it doesn't exist. /// diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 7066b366..582e2431 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using Tensorflow.Checkpoint; +using Tensorflow.Operations.Activation; using Tensorflow.Train; using Tensorflow.Training; using static Tensorflow.Binding; @@ -117,8 +118,7 @@ namespace Tensorflow } else { - Debug.Assert(variable is ResourceVariable); - yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + yield return new ResourceVariableSaveable(variable, "", name); } } else @@ -215,7 +215,7 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary> saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` @@ -251,7 +251,7 @@ namespace Tensorflow specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); } } - Dictionary> res = new(); + Dictionary> res = new(); res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); return res; } @@ 
-270,25 +270,6 @@ namespace Tensorflow { return tf.compat.as_str(x); } - } - - public class SaveableCompatibilityConverter: Trackable - { - private Trackable _obj; - private IList _saveables; - public SaveableCompatibilityConverter(Trackable obj, IList saveables) - { - _obj= obj; - _saveables= saveables; - } - - public Trackable Obj => _obj; - public IList mySaveables=> _saveables; - - public override IDictionary>> serialize_to_tensors() - { - return saveable_object_to_tensor_dict(_saveables); - } /// /// Converts a list of SaveableObjects to a tensor dictionary. @@ -299,11 +280,11 @@ namespace Tensorflow Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { - foreach(var spec in saveable.specs) + foreach (var spec in saveable.specs) { // skip the check that if `spec` is callable. - var name = saveable_object_util.convert_to_string(spec.name); - var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + var name = convert_to_string(spec.name); + var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; @@ -316,5 +297,81 @@ namespace Tensorflow } return tensor_dict; } + + /// + /// Generates `Trackable._restore_from_tensors` from SaveableObjects. 
+ /// + /// + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + { + return (restored_tensors) => + { + Dictionary restored_ops = new(); + + foreach(var saveable in saveables) + { + List saveable_restored_tensors = new(); + foreach(var spec in saveable.specs) + { + var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + + var maybe_tensor = restored_tensors[name]; + IDictionary dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + dict = new Dictionary(); + dict[""] = maybe_tensor.GetValueA(); + } + else + { + dict = maybe_tensor.GetValueB(); + } + saveable_restored_tensors.Add(dict[slice_spec]); + } + restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); + } + return restored_ops; + }; + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private object _obj; + private IList _saveables; + public SaveableCompatibilityConverter(object obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public object Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary>> serialize_to_tensors() + { + return saveable_object_util.saveable_object_to_tensor_dict(_saveables); + } + + /// + /// Returns the restore ops defined in the Saveables. + /// + /// + /// + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + List expected_keys = new(); + foreach(var saveable in _saveables) + { + expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); + } + if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) + { + throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." 
+ + $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); + } + return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index a677044a..434d51b6 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -42,11 +42,11 @@ namespace Tensorflow.Train protected IList _unconditional_checkpoint_dependencies; - protected IDictionary> _self_saveable_object_factories = - new Dictionary>(); + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); private bool _manual_tracking = true; - private static Trackable _none = new Function(); + private static Trackable _none = new AutoTrackable(); /// /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. /// The `None` can be any object that inherits `Trackable`. @@ -225,7 +225,7 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary> gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { @@ -251,6 +251,11 @@ namespace Tensorflow.Train { throw new NotImplementedException(); } + + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 756024db..4005d564 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -6,6 +6,8 @@ using Tensorflow.Train; using static Tensorflow.Binding; using System.Collections.Generic; using Tensorflow.ModelSaving; +using 
System.Diagnostics; +using Tensorflow.Checkpoint; namespace Tensorflow { @@ -13,6 +15,7 @@ namespace Tensorflow { protected string _name; public virtual string Name => _handle_name; + public virtual string SharedName => _name; protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; @@ -50,6 +53,7 @@ namespace Tensorflow public Graph Graph => handle.graph; public string Device => handle.Device; EagerResourceDeleter eager_resource_deleter; + public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; public BaseResourceVariable() { @@ -77,6 +81,11 @@ namespace Tensorflow _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); } + else if(handle is null) + { + // TODO: fix this dangerous change. + _handle = IntPtr.Zero; + } else { _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); @@ -247,5 +256,60 @@ namespace Tensorflow else return value(); } + + public override (IDictionary, IDictionary) map_resources(SaveOptions save_options) + { + BaseResourceVariable new_variable; + if (save_options.experimental_variable_policy.save_variable_devices()) + { + tf.device(this.Device); + Debug.Assert(this is ResourceVariable); + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + else + { + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + Dictionary obj_map = new(); + Dictionary resource_map = new(); + obj_map[this] = new_variable; + resource_map[this.handle] = new_variable.handle; + return (obj_map, resource_map); + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// Subclasses of ResourceVariables could choose to override this method to + /// customize extra information to provide when saving a SavedModel. 
+ /// + /// + /// + public virtual void write_object_proto(SavedObject proto, SaveOptions options) + { + resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); + } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } + + public Tensor is_initialized(string name = null) + { + return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); + } + + public Tensor read_value_no_copy() + { + Tensor value = null; + tf_with(ops.name_scope("Read"), _ => + { + // TODO: `no_copy = true`. + value = _read_variable_op(); + }); + return array_ops.identity(value); + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 6093f810..1645d713 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -41,6 +41,7 @@ namespace Tensorflow VariableAggregation aggregation = VariableAggregation.None, Shape shape = null) { + Aggregation = aggregation; if (variable_def != null) { if (initial_value != null) @@ -237,12 +238,5 @@ namespace Tensorflow { return _graph_element.eval(session); } - - public override IDictionary> gather_saveables_for_checkpoint() - { - var res = new Dictionary>(); - res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; - return res; - } } } diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs new file mode 100644 index 00000000..6c034995 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Variables +{ + /// + /// A variable with no initializer. 
+ /// + public sealed class UninitializedVariable: BaseResourceVariable + { + // TODO: complete the arg list. + public UninitializedVariable( + bool trainable = true, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null, + Tensor extra_handle_data = null) + { + string unique_id = ""; + string handle_name = ""; + tf_with(ops.init_scope(), (x) => + { + _in_graph_mode = !tf.Context.executing_eagerly(); + tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => + { + handle_name = ops.name_from_scope_name(name); + string? shared_name; + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } + else + { + unique_id = $"{handle_name}-{ops.uid()}"; + shared_name = null; + } + var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); + // skip the assignment of `handle._parent_trackable` because of lack of API. + // skip the assignment of `handle._name` and `handle._unique_id` because of accessibility. 
+ + if (_in_graph_mode) + { + tf_with(ops.name_scope("Read"), _ => + { + tf.device(handle.Device); + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + // _maybe_set_handle_data(dtype, handle, value) + _graph_element = value; + }); + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); + } + else + { + _graph_element = null; + } + }); + }); + _shape = shape; + _dtype = dtype; + base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 23c40fbf..a221444b 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Engine } } - var layer_config = generic_utils.serialize_keras_object(layer); + var layer_config = generic_utils.serialize_layer_to_config(layer); layer_config.Name = layer.Name; layer_config.InboundNodes = filtered_inbound_nodes; layer_configs.Add(layer_config); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index ffb6f71b..fc405d87 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine; public abstract partial class Layer { - public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 6b064716..03b4b742 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -18,6 +18,7 @@ using 
System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -105,5 +106,7 @@ namespace Tensorflow.Keras.Layers { return new InputLayer(args as InputLayerArgs); } + + public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 6a6e418c..4ff8f02f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -55,6 +55,7 @@ public partial class KerasSavedModelUtils var metadata = generate_keras_metadata(saved_nodes, node_paths); File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); + //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString()); if (!include_optimizer) { @@ -100,7 +101,8 @@ public partial class KerasSavedModelUtils Identifier = layer.ObjectIdentifier, Metadata = layer.TrackingMetadata }; - + + metadata.Nodes.Add(saved_object); } return metadata; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index fc7eab3a..f7e1bf45 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -24,26 +24,26 @@ public partial class KerasSavedModelUtils // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. // TODO: change the inherits of `Variable` and revise the implmentation. 
- var variables = layer.Variables.Select(x => + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var trainable_variables = layer.TrainableVariables.Select(x => + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var non_trainable_variables = layer.non_trainable_variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); Dictionary res = new(); - res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); - res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); - res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); return res; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 0235f87b..60c4ee5b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ 
b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public abstract class SavedModelSaver { - private Trackable _obj; + protected Trackable _obj; public SavedModelSaver(Trackable obj) { _obj = obj; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index b092b595..655127af 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -2,6 +2,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; using Tensorflow.Train; @@ -9,10 +10,11 @@ namespace Tensorflow.Keras.Saving.SavedModel; public class LayerSavedModelSaver: SavedModelSaver { - private Layer _obj; + private Layer _layer; public LayerSavedModelSaver(Layer obj): base(obj) { _obj = obj; + _layer = obj; } public override string ObjectIdentifier { @@ -68,8 +70,8 @@ public class LayerSavedModelSaver: SavedModelSaver /// private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { - var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); - var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); functions["_default_save_signature"] = null; @@ -81,17 +83,18 @@ public class LayerSavedModelSaver: SavedModelSaver get { JObject metadata = new JObject(); - metadata["name"] = _obj.Name; - metadata["trainable"] = _obj.Trainable; + metadata["name"] = _layer.Name; + metadata["trainable"] = _layer.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = 
policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; - metadata["autocast"] = _obj.AutoCast; - - metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + metadata["autocast"] = _layer.AutoCast; + + var temp = JObject.FromObject(get_serialized(_layer)); + metadata.Merge(temp, new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge @@ -108,4 +111,46 @@ public class LayerSavedModelSaver: SavedModelSaver return new Dictionary(); //return generic_utils.serialize_keras_object(obj); } +} + +public class InputLayerSavedModelSaver: SavedModelSaver +{ + public InputLayerSavedModelSaver(Layer obj) : base(obj) + { + + } + public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override string TrackingMetadata + { + get + { + if(_obj is not Layer) + { + throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); + } + var layer = (Layer)_obj; + var info = new + { + class_name = layer.GetType().Name, + name = layer.Name, + dtype = layer.DType, + //sparse = layer.sparse, + //ragged = layer.ragged, + batch_input_shape = layer.BatchInputShape, + config = layer.get_config() + }; + return JsonConvert.SerializeObject(info); + } + } } \ No newline at end of file diff --git 
a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index c2839cdc..68903eb2 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Collections; +using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.Saving; @@ -22,7 +24,12 @@ namespace Tensorflow.Keras.Utils { public class generic_utils { - public static LayerConfig serialize_keras_object(ILayer instance) + /// + /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. + /// + /// + /// + public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); return new LayerConfig From 2ab0bdbc8690b0048eaeb5ab5069e042b1a88d25 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 19:08:50 +0800 Subject: [PATCH 10/10] Add more implementations to the keras part of pb model save. 
--- .../ArgsDefinition/Activation/SoftmaxArgs.cs | 17 ++- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 19 +++ .../Keras/ArgsDefinition/Core/DenseArgs.cs | 42 +++++- .../ArgsDefinition/Core/InputLayerArgs.cs | 17 ++- .../Keras/ArgsDefinition/DataAdapterArgs.cs | 3 +- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 3 +- .../Keras/ArgsDefinition/LayerArgs.cs | 31 +++-- .../Keras/ArgsDefinition/NodeArgs.cs | 6 +- .../Keras/ArgsDefinition/OptimizerV2Args.cs | 6 +- .../ArgsDefinition/Reshaping/FlattenArgs.cs | 7 +- .../CustomizedActivationJsonConverter.cs | 50 +++++++ .../Common/CustomizedAxisJsonConverter.cs | 48 +++++++ .../CustomizedNodeConfigJsonConverter.cs | 73 ++++++++++ .../Common/CustomizedShapeJsonConverter.cs | 67 ++++++++++ .../Keras/Engine/InputSpec.cs | 31 ++++- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 5 +- .../Keras/Saving/IKerasConfig.cs | 15 +++ .../Keras/Saving/LayerConfig.cs | 9 +- .../Keras/Saving/ModelConfig.cs | 9 +- .../Keras/Saving/NodeConfig.cs | 7 +- .../Keras/Saving/TensorShapeConfig.cs | 21 +++ src/TensorFlowNET.Core/NumPy/Axis.cs | 11 +- src/TensorFlowNET.Core/Numpy/Shape.cs | 3 + .../Operations/Initializers/Constant.cs | 10 ++ .../Operations/Initializers/GlorotUniform.cs | 10 +- .../Operations/Initializers/IInitializer.cs | 7 + .../Operations/Initializers/Ones.cs | 7 + .../Operations/Initializers/Orthogonal.cs | 5 + .../Operations/Initializers/RandomNormal.cs | 12 ++ .../Operations/Initializers/RandomUniform.cs | 12 ++ .../Initializers/TruncatedNormal.cs | 11 ++ .../Initializers/VarianceScaling.cs | 13 ++ .../Operations/Initializers/Zeros.cs | 5 + .../Operations/NnOps/RNNCell.cs | 5 +- .../Tensorflow.Binding.csproj | 1 + src/TensorFlowNET.Core/Tensors/dtypes.cs | 18 +++ .../{ITrackable.cs => IWithTrackable.cs} | 2 +- src/TensorFlowNET.Core/Training/Trackable.cs | 2 +- .../Engine/Functional.GetConfig.cs | 31 +++-- src/TensorFlowNET.Keras/Engine/Functional.cs | 18 +++ src/TensorFlowNET.Keras/Engine/Layer.cs | 6 +- 
src/TensorFlowNET.Keras/Engine/Model.Save.cs | 6 +- .../Layers/Activation/ELU.cs | 1 + .../Layers/Activation/Exponential.cs | 1 + .../Layers/Activation/SELU.cs | 9 +- .../Layers/Attention/Attention.cs | 3 +- .../Layers/Attention/BaseDenseAttention.cs | 3 +- .../Layers/Convolution/Conv2DTranspose.cs | 1 + .../Layers/Convolution/Convolutional.cs | 1 + src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 1 + .../Layers/Core/Embedding.cs | 1 + .../Layers/Cropping/Cropping1D.cs | 1 + .../Layers/Cropping/Cropping2D.cs | 3 +- .../Layers/Cropping/Cropping3D.cs | 3 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 4 +- .../Layers/Merging/Concatenate.cs | 1 + .../Layers/Merging/Merge.cs | 1 + .../Normalization/BatchNormalization.cs | 1 + .../Normalization/LayerNormalization.cs | 1 + .../Layers/Reshaping/Permute.cs | 1 + .../Layers/Rnn/SimpleRNN.cs | 1 + .../Layers/Rnn/StackedRNNCells.cs | 3 +- .../Saving/SavedModel/Save.cs | 2 +- .../Saving/SavedModel/layer_serialization.cs | 33 +++-- .../Saving/TensorShapeConfig.cs | 15 --- .../Saving/serialization.cs | 125 ++++++++++++++++++ .../Utils/base_layer_utils.cs | 2 +- .../Utils/generic_utils.cs | 14 +- .../Layers/ModelSaveTest.cs | 5 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 40 ++++-- 70 files changed, 849 insertions(+), 109 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs rename src/TensorFlowNET.Core/Training/{ITrackable.cs => IWithTrackable.cs} (82%) delete mode 100644 
src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs create mode 100644 src/TensorFlowNET.Keras/Saving/serialization.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs index ca35d75d..a37973bc 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -1,9 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class SoftmaxArgs : LayerArgs { - public Axis axis { get; set; } = -1; - } + public class SoftmaxArgs : LayerArgs + { + [JsonProperty("axis")] + public Axis axis { get; set; } = -1; + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs new file mode 100644 index 00000000..66b34a1a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -0,0 +1,19 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AutoSerializeLayerArgs: LayerArgs + { + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape 
BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs index e9b3c2fd..8f4facbd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -1,13 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; +using System.Xml.Linq; +using Tensorflow.Operations.Initializers; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { + // TODO: `activity_regularizer` public class DenseArgs : LayerArgs { /// /// Positive integer, dimensionality of the output space. /// + [JsonProperty("units")] public int Units { get; set; } /// @@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + /// /// Whether the layer uses a bias vector. /// + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. 
/// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. /// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs index 723109c2..be43e0a6 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -1,9 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Tensorflow.Keras.Common; + +namespace Tensorflow.Keras.ArgsDefinition { public class InputLayerArgs : LayerArgs { + [JsonIgnore] public Tensor InputTensor { get; set; } - public bool Sparse { get; set; } + [JsonProperty("sparse")] + public virtual bool Sparse { get; set; } + [JsonProperty("ragged")] public bool Ragged { get; set; } + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape 
BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index f3cca438..8ce1ec65 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataAdapterArgs + public class DataAdapterArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index b6e6849b..fd603a85 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataHandlerArgs + public class DataHandlerArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index 4df4fb2b..febf1417 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -1,51 +1,54 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class LayerArgs + [JsonObject(MemberSerialization.OptIn)] + public class LayerArgs: IKerasConfig { /// /// Indicates whether the layer's weights are updated during training /// and whether the layer's updates are run during training. 
/// - public bool Trainable { get; set; } = true; - - public string Name { get; set; } + public virtual bool Trainable { get; set; } = true; + public virtual string Name { get; set; } /// /// Only applicable to input layers. /// - public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; + public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; /// /// Whether the `call` method can be used to build a TF graph without issues. /// This attribute has no effect if the model is created using the Functional /// API. Instead, `model.dynamic` is determined based on the internal layers. /// - public bool Dynamic { get; set; } = false; + public virtual bool Dynamic { get; set; } = false; /// /// Only applicable to input layers. /// - public Shape InputShape { get; set; } + public virtual Shape InputShape { get; set; } /// /// Only applicable to input layers. /// - public Shape BatchInputShape { get; set; } + public virtual Shape BatchInputShape { get; set; } - public int BatchSize { get; set; } = -1; + public virtual int BatchSize { get; set; } = -1; /// /// Initial weight values. /// - public float[] Weights { get; set; } + public virtual float[] Weights { get; set; } /// /// Regularizer function applied to the output of the layer(its "activation"). 
/// - public IRegularizer ActivityRegularizer { get; set; } + public virtual IRegularizer ActivityRegularizer { get; set; } - public bool Autocast { get; set; } + public virtual bool Autocast { get; set; } - public bool IsFromConfig { get; set; } + public virtual bool IsFromConfig { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 0d9e26ac..ad55ff61 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class NodeArgs + public class NodeArgs: IKerasConfig { public ILayer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs index e2a0e43c..6256fd32 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class OptimizerV2Args + public class OptimizerV2Args: IKerasConfig { public string Name { get; set; } public float LearningRate { get; set; } = 0.001f; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs index c2b48cc2..91ffc205 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs @@ -1,7 +1,10 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class FlattenArgs : LayerArgs + public class 
FlattenArgs : AutoSerializeLayerArgs { + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs new file mode 100644 index 00000000..1bc13caf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs @@ -0,0 +1,50 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedActivationJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Activation); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(""); + token.WriteTo(writer); + } + else if (value is not Activation) + { + throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Activation)!.GetType().Name); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + throw new NotImplementedException(); + //var dims = serializer.Deserialize(reader, typeof(string)); + //if (dims is null) + //{ + // throw new ValueError("Cannot deserialize 'null' to `Activation`."); + //} + //return new Shape((long[])(dims!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs new file mode 100644 index 00000000..4e190605 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -0,0 +1,48 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedAxisJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Axis); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(new int[] { }); + token.WriteTo(writer); + } + else if (value is not Axis) + { + throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Axis)!.axis); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + var axis = serializer.Deserialize(reader, typeof(int[])); // Axis wraps int[]; a boxed long[] cannot be cast to int[] + if (axis is null) + { + throw new ValueError("Cannot deserialize 'null' to `Axis`."); + } + return new Axis((int[])(axis!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs new file mode 100644 index 00000000..1ad19fc8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedNodeConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(NodeConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + // JToken.FromObject(null) throws ArgumentNullException; emit a JSON null directly. + writer.WriteNull(); + } + else if (value is not NodeConfig) + { + throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var config = value as NodeConfig; + var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object?
existingValue, JsonSerializer serializer) + { + var values = serializer.Deserialize(reader, typeof(object[])) as object[]; + if (values is null) + { + throw new ValueError("Cannot deserialize 'null' to `NodeConfig`."); + } + if(values.Length != 3) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + if (values[0] is not string) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{(values[0]?.GetType().Name ?? "null")}`"); + } + if (values[1] is not (int or long)) // Json.NET boxes JSON integers read into object[] as Int64 + { + throw new TypeError($"The second value of `NodeConfig` is expected to be `int`, but got `{(values[1]?.GetType().Name ?? "null")}`"); + } + if (values[2] is not (int or long)) + { + throw new TypeError($"The third value of `NodeConfig` is expected to be `int`, but got `{(values[2]?.GetType().Name ?? "null")}`"); + } + return new NodeConfig() + { + Name = values[0] as string, + NodeIndex = Convert.ToInt32(values[1]), + TensorIndex = Convert.ToInt32(values[2]) + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs new file mode 100644 index 00000000..300cb2f2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -0,0 +1,67 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedShapeJsonConverter: JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Shape); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object?
value, JsonSerializer serializer) + { + if(value is null) + { + // JToken.FromObject(null) throws ArgumentNullException; emit a JSON null directly. + writer.WriteNull(); + } + else if(value is not Shape) + { + throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var shape = (value as Shape)!; + long?[] dims = new long?[shape.ndim]; + for(int i = 0; i < dims.Length; i++) + { + if (shape.dims[i] == -1) + { + dims[i] = null; + } + else + { + dims[i] = shape.dims[i]; + } + } + var token = JToken.FromObject(dims); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + if(dims is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + long[] convertedDims = new long[dims.Length]; + for(int i = 0; i < dims.Length; i++) + { + convertedDims[i] = dims[i] ?? (-1); + } + return new Shape(convertedDims); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 7280594b..6743935c 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -16,23 +16,27 @@ using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { /// /// Specifies the ndim, dtype and shape of every input to a layer. /// - public class InputSpec + public class InputSpec: IKerasConfigable { public int? ndim; + public int? max_ndim; public int? min_ndim; Dictionary axes; Shape shape; + TF_DataType dtype; public int[] AllAxisDim; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, int? min_ndim = null, + int?
max_ndim = null, Dictionary axes = null, Shape shape = null) { @@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine axes = new Dictionary(); this.axes = axes; this.min_ndim = min_ndim; + this.max_ndim = max_ndim; this.shape = shape; + this.dtype = dtype; if (ndim == null && shape != null) this.ndim = shape.ndim; @@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine AllAxisDim = axes.Select(x => x.Value).ToArray(); } + public IKerasConfig get_config() + { + return new Config() + { + DType = dtype == TF_DataType.DtInvalid ? null : dtype, + Shape = shape, + Ndim = ndim, + MinNdim = min_ndim, + MaxNdim = max_ndim, + Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) + }; + } + public override string ToString() => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; + + public class Config: IKerasConfig + { + public TF_DataType? DType { get; set; } + public Shape Shape { get; set; } + public int? Ndim { get; set; } + public int? MinNdim { get;set; } + public int? MaxNdim { get;set; } + public IDictionary Axes { get; set; } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f1ca5632..ebf3358d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,11 +1,12 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer: ITrackable + public interface ILayer: IWithTrackable, IKerasConfigable { string Name { get; } bool Trainable { get; } @@ -19,8 +20,8 @@ namespace Tensorflow.Keras List NonTrainableWeights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } + TensorShapeConfig BuildInputShape { get; } TF_DataType DType { get; } int count_params(); - LayerArgs get_config(); } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs 
b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs new file mode 100644 index 00000000..1217e1e5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public interface IKerasConfig + { + } + + public interface IKerasConfigable + { + IKerasConfig get_config(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs index b8b8cab4..4ce290c8 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class LayerConfig + public class LayerConfig: IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("class_name")] public string ClassName { get; set; } + [JsonProperty("config")] public LayerArgs Config { get; set; } + [JsonProperty("inbound_nodes")] public List InboundNodes { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index abfb235b..cac19180 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,15 +1,20 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class ModelConfig + public class ModelConfig : IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("layers")] public List Layers { get; set; } + [JsonProperty("input_layers")] public List InputLayers { get; set; } + 
[JsonProperty("output_layers")] public List OutputLayers { get; set; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 3132248e..20e2fef5 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -1,10 +1,13 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow.Keras.Saving { - public class NodeConfig + [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] + public class NodeConfig : IKerasConfig { public string Name { get; set; } public int NodeIndex { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs new file mode 100644 index 00000000..7abcfde2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs @@ -0,0 +1,21 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Saving +{ + public class TensorShapeConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } = "TensorShape"; + [JsonProperty("items")] + public long?[] Items { get; set; } + + public static implicit operator Shape(TensorShapeConfig shape) + => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + + public static implicit operator TensorShapeConfig(Shape shape) + => shape == null ? null : new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? (long?)null : x).ToArray() }; // null-safe, mirroring the Shape operator above; -1 (unknown dim) maps to null + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 6c7189df..709ca9b2 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -14,20 +14,29 @@ limitations under the License.
******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow { - public record Axis(params int[] axis) + [JsonConverter(typeof(CustomizedAxisJsonConverter))] + public class Axis { + public int[] axis { get; set; } public int size => axis == null ? -1 : axis.Length; public bool IsScalar { get; init; } public int this[int index] => axis[index]; + public Axis(params int[] axis) + { + this.axis = axis; + } + public static implicit operator int[]?(Axis axis) => axis?.axis; diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index bc79fefc..ecf73586 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -14,14 +14,17 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; using Tensorflow.NumPy; namespace Tensorflow { + [JsonConverter(typeof(CustomizedShapeJsonConverter))] public class Shape { public int ndim => _dims == null ? -1 : _dims.Length; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index fdcb5aff..e7e9955c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Constant : IInitializer @@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers T value; bool _verify_shape; + private readonly Dictionary _config; + + public string ClassName => "Constant"; + public IDictionary Config => _config; + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) { this.value = value; this.dtype = dtype; _verify_shape = verify_shape; + + _config = new Dictionary(); + _config["value"] = this.value; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index d97d8830..def1cb7a 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -14,10 +14,17 @@ limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class GlorotUniform : VarianceScaling { + private readonly Dictionary _config; + + public override string ClassName => "GlorotUniform"; + public override IDictionary Config => _config; + public GlorotUniform(float scale = 1.0f, string mode = "FAN_AVG", bool uniform = true, @@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers seed: seed, dtype: dtype) { - + _config = new Dictionary(); + _config["seed"] = _seed; } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 50d4d503..9748b100 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -14,10 +14,17 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; +using System.Collections.Generic; + namespace Tensorflow { public interface IInitializer { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } Tensor Apply(InitializerArgs args); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 02d3c93b..3077a1e0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -14,12 +14,19 @@ limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Ones : IInitializer { private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "Ones"; + public IDictionary Config => new Dictionary(); + public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) { this.dtype = dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 254a7ee7..cdc1c3ed 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -1,9 +1,14 @@ using System; +using System.Collections.Generic; namespace Tensorflow.Operations.Initializers { public class Orthogonal : IInitializer { + private readonly Dictionary _config; + + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); public Tensor Apply(InitializerArgs args) { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 029b311b..21fa7e2b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomNormal : IInitializer @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomNormal"; + public IDictionary Config => _config; + public RandomNormal(float mean = 0.0f, float stddev = 0.05f, int? 
seed = null, @@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers this.stddev = stddev; this.seed = seed; this.dtype = dtype; + + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index a49d5921..87404708 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomUniform : IInitializer @@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers private float maxval; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomUniform"; + public IDictionary Config => _config; + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) { this.dtype = dtype; this.minval = minval; this.maxval = maxval; this.seed = seed; + + _config = new Dictionary(); + _config["minval"] = this.minval; + _config["maxval"] = this.maxval; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 048c11e7..c1c3e999 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class TruncatedNormal : IInitializer @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "TruncatedNormal"; + public IDictionary Config => _config; + public TruncatedNormal(float mean = 0.0f, float stddev = 1.0f, int? seed = null, @@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers this.stddev = stddev; this.seed = seed; this.dtype = dtype; + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index d313f4c9..f104e8e8 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -15,7 +15,9 @@ ******************************************************************************/ using System; +using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; namespace Tensorflow.Operations.Initializers { @@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers protected int? 
_seed; protected TF_DataType _dtype; protected bool _uniform; + private readonly Dictionary _config; + + public virtual string ClassName => "VarianceScaling"; + + public virtual IDictionary Config => _config; public VarianceScaling(float factor = 2.0f, string mode = "FAN_IN", @@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers _seed = seed; _dtype = dtype; _uniform = uniform; + + _config = new(); + _config["scale"] = _scale; + _config["mode"] = _mode; + _config["distribution"] = _distribution; + _config["seed"] = _seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index 5d045292..c4ed25a1 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Zeros : IInitializer @@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers Shape shape; TF_DataType dtype; + public string ClassName => "Zeros"; + public IDictionary Config => new Dictionary(); + public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { this.shape = shape; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 734f2608..c29ed47b 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -20,6 +20,7 @@ using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; @@ -76,6 +77,8 @@ namespace Tensorflow public Shape BatchInputShape => throw new 
NotImplementedException(); + public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); + public TF_DataType DType => throw new NotImplementedException(); protected bool built = false; public bool Built => built; @@ -144,7 +147,7 @@ namespace Tensorflow throw new NotImplementedException(); } - public LayerArgs get_config() + public IKerasConfig get_config() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 0ebe61d0..7068ed47 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io + diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 372ac676..deeb9e4b 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -202,6 +202,24 @@ namespace Tensorflow _ => type.ToString() }; + public static string as_python_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "str", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => "uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + public static int get_datatype_size(this TF_DataType type) => type.as_base_dtype() switch { diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs similarity index 82% rename from src/TensorFlowNET.Core/Training/ITrackable.cs rename to src/TensorFlowNET.Core/Training/IWithTrackable.cs index e4ef2c8f..87eda879 100644 --- a/src/TensorFlowNET.Core/Training/ITrackable.cs 
+++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs @@ -5,7 +5,7 @@ using Tensorflow.Train; namespace Tensorflow.Training { - public interface ITrackable + public interface IWithTrackable { Trackable GetTrackable(); } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 434d51b6..132571f2 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -26,7 +26,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class Trackable: ITrackable + public abstract class Trackable: IWithTrackable { /// /// Corresponding to tensorflow/python/trackable/constants.py diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index a221444b..3aeb3200 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine { public partial class Functional { - public ModelConfig get_config() + public override IKerasConfig get_config() { return get_network_config(); } @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine { Name = name }; - + var node_conversion_map = new Dictionary(); foreach (var layer in _self_tracked_trackables) { @@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine } var layer_configs = new List(); - foreach (var layer in _self_tracked_trackables) + using (SharedObjectSavingScope.Enter()) { - var filtered_inbound_nodes = new List(); - foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + foreach (var layer in _self_tracked_trackables) { - var node_key = _make_node_key(layer.Name, original_node_index); - if (NetworkNodes.Contains(node_key) && !node.is_input) + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) { - var node_data = node.serialize(_make_node_key, 
node_conversion_map); - filtered_inbound_nodes.append(node_data); + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + filtered_inbound_nodes.append(node_data); + } } - } - var layer_config = generic_utils.serialize_layer_to_config(layer); - layer_config.Name = layer.Name; - layer_config.InboundNodes = filtered_inbound_nodes; - layer_configs.Add(layer_config); + var layer_config = generic_utils.serialize_layer_to_config(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } } config.Layers = layer_configs; diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 7c8812ad..44eaef53 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -70,6 +70,7 @@ namespace Tensorflow.Keras.Engine this.inputs = inputs; this.outputs = outputs; built = true; + _buildInputShape = inputs.shape; if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); @@ -357,5 +358,22 @@ namespace Tensorflow.Keras.Engine return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); } + + protected override void _init_set_name(string name, bool zero_based = true) + { + if (string.IsNullOrEmpty(name)) + { + string class_name = GetType().Name; + if (this.GetType() == typeof(Functional)) + { + class_name = "Model"; + } + this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); + } + else + { + this.name = name; + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index a2f92ba8..31b37d68 100644 --- 
a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -61,6 +61,7 @@ namespace Tensorflow.Keras.Engine /// Provides information about which inputs are compatible with the layer. /// protected InputSpec inputSpec; + public InputSpec InputSpec => inputSpec; bool dynamic = true; public bool SupportsMasking { get; set; } protected List _trainable_weights; @@ -79,6 +80,8 @@ namespace Tensorflow.Keras.Engine protected bool computePreviousMask; protected List updates; public Shape BatchInputShape => args.BatchInputShape; + protected TensorShapeConfig _buildInputShape = null; + public TensorShapeConfig BuildInputShape => _buildInputShape; List inboundNodes; public List InboundNodes => inboundNodes; @@ -223,6 +226,7 @@ namespace Tensorflow.Keras.Engine public virtual void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } @@ -310,7 +314,7 @@ namespace Tensorflow.Keras.Engine public List Variables => weights; - public virtual LayerArgs get_config() + public virtual IKerasConfig get_config() => args; } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59b205e4..85da920e 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; @@ -30,7 +31,10 @@ namespace Tensorflow.Keras.Engine } else { - KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + using (SharedObjectSavingScope.Enter()) + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 6e790a26..45f64720 
100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers { { throw new ValueError("Alpha must be a number greater than 0."); } + _buildInputShape = input_shape; built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index aba175de..2fd2caee 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers { } public override void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index b12d7dee..1ef8d0e5 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers { // SELU has no arguments } public override void build(Shape input_shape) { - if ( alpha < 0f ) { - throw new ValueError("Alpha must be a number greater than 0."); - } - built = true; + if ( alpha < 0f ) { + throw new ValueError("Alpha must be a number greater than 0."); + } + _buildInputShape = input_shape; + built = true; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index 6f6dd7e8..c5131630 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers return scores; } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; //var config = new Dictionary { // { // "use_scale", diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 3f618b5d..1348e19c 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -5,6 +5,7 @@ using static Tensorflow.KerasApi; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; /// /// Base class for attention layers that can be used in sequence DNN/CNN models. 
@@ -252,6 +253,6 @@ namespace Tensorflow.Keras.Layers return tf.logical_and(x, y); } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; } } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index e0a337ca..b8286be6 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -49,6 +49,7 @@ namespace Tensorflow.Keras.Layers initializer: bias_initializer, trainable: true); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 912a429b..933aa9cf 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -98,6 +98,7 @@ namespace Tensorflow.Keras.Layers name: tf_op_name); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = false) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index e4c22745..ca8007d0 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Keras.Layers public override void build(Shape input_shape) { + _buildInputShape = input_shape; var last_dim = input_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 79f4e5ce..606f387b 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -62,6 +62,7 @@ namespace Tensorflow.Keras.Layers name: "embeddings"); tf.Context.graph_mode(); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs index 45f5bf0f..44b338c2 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs @@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Layers { throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); } built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs index 6cb03e1e..1f33ee3a 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs @@ -13,7 +13,8 @@ namespace Tensorflow.Keras.Layers { this.args = args; } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs index 2d6751bf..838a5043 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs @@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers { } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 50c66be7..c1ec0ddc 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -300,7 +300,8 @@ namespace Tensorflow.Keras.Layers => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName("linear") + Activation = GetActivationByName("linear"), + ActivationName = "linear" }); /// @@ -321,6 +322,7 @@ namespace Tensorflow.Keras.Layers { Units = units, Activation = GetActivationByName(activation), + ActivationName = activation, InputShape = input_shape }); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index 5f821760..da7e857a 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -37,6 +37,7 @@ namespace Tensorflow.Keras.Layers }).ToArray(); shape_set.Add(shape); }*/ + _buildInputShape = input_shape; } protected override Tensors _merge_function(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 0363d58f..3cd43af9 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -17,6 +17,7 @@ namespace Tensorflow.Keras.Layers public override void build(Shape input_shape) { // output_shape = input_shape.dims[1^]; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index dac92f81..c0b16c81 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -118,6 +118,7 @@ namespace Tensorflow.Keras.Layers throw new NotImplementedException("build when renorm is true"); built = true; + _buildInputShape = input_shape; } public override Shape ComputeOutputShape(Shape input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index 5eebd735..e19b9c30 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -81,6 +81,7 @@ namespace Tensorflow.Keras.Layers _fused = _fused_can_be_used(ndims); built = true; + _buildInputShape = input_shape; } bool _fused_can_be_used(int ndims) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 868506b6..8e7a19a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -24,6 +24,7 @@ namespace Tensorflow.Keras.Layers { permute = new int[input_shape.rank]; dims.CopyTo(permute, 1); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index c8366ff4..38abe2a7 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn public override void build(Shape input_shape) { var input_dim = input_shape[-1]; + _buildInputShape = input_shape; kernel = add_weight("kernel", (input_shape[-1], args.Units), initializer: args.KernelInitializer diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index eead274a..20962df1 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -4,6 +4,7 @@ using System.ComponentModel; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Rnn { @@ -136,7 +137,7 @@ namespace Tensorflow.Keras.Layers.Rnn // self.built = True } - public override LayerArgs get_config() + public override IKerasConfig get_config() { throw new NotImplementedException(); //def get_config(self): diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 4ff8f02f..9d1c9609 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -79,7 +79,7 @@ public partial class KerasSavedModelUtils var path = node_paths[node]; string node_path; - if (path is null) + if (path is null || path.Count() == 0) { node_path = "root"; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 655127af..8675ea65 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ 
b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Newtonsoft.Json; using Newtonsoft.Json.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; @@ -85,31 +86,38 @@ public class LayerSavedModelSaver: SavedModelSaver JObject metadata = new JObject(); metadata["name"] = _layer.Name; metadata["trainable"] = _layer.Trainable; - // metadata["expects_training_arg"] = _obj._expects_training_arg; - // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + // TODO: implement `expects_training_arg`. + metadata["expects_training_arg"] = false; + metadata["dtype"] = _layer.DType.as_python_name(); metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; metadata["autocast"] = _layer.AutoCast; - var temp = JObject.FromObject(get_serialized(_layer)); - metadata.Merge(temp, new JsonMergeSettings + if(_layer.InputSpec is not null) + { + metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); + } + + metadata.Merge(get_serialized(_layer), new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge }); // skip the check of `input_spec` and `build_input_shape` for the lack of members. // skip the check of `activity_regularizer` for the type problem. + if(_layer.BuildInputShape is not null) + { + metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); + } return metadata.ToString(); } } - public static IDictionary get_serialized(Layer obj) + public static JObject get_serialized(Layer obj) { - // TODO: complete the implmentation (need to revise `get_config`). 
- return new Dictionary(); - //return generic_utils.serialize_keras_object(obj); + return generic_utils.serialize_keras_object(obj); } } @@ -135,18 +143,19 @@ public class InputLayerSavedModelSaver: SavedModelSaver { get { - if(_obj is not Layer) + if(_obj is not InputLayer) { throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); } - var layer = (Layer)_obj; + var layer = (InputLayer)_obj; + var config = (layer.get_config() as InputLayerArgs)!; var info = new { class_name = layer.GetType().Name, name = layer.Name, dtype = layer.DType, - //sparse = layer.sparse, - //ragged = layer.ragged, + sparse = config.Sparse, + ragged = config.Ragged, batch_input_shape = layer.BatchInputShape, config = layer.get_config() }; diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs deleted file mode 100644 index 4c2ecc0d..00000000 --- a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Tensorflow.Keras.Saving -{ - public class TensorShapeConfig - { - public string ClassName { get; set; } - public int?[] Items { get; set; } - - public static implicit operator Shape(TensorShapeConfig shape) - => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); - } -} diff --git a/src/TensorFlowNET.Keras/Saving/serialization.cs b/src/TensorFlowNET.Keras/Saving/serialization.cs new file mode 100644 index 00000000..d5e46d11 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/serialization.cs @@ -0,0 +1,125 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving +{ + // TODO: make it thread safe. 
+    /// <summary>
+    /// Ambient scope that assigns an integer id to every object serialized while the scope
+    /// is active, so that an object serialized several times is always written with the
+    /// same id.
+    /// </summary>
+    /// <remarks>
+    /// One static instance is shared by nested <see cref="Enter"/> calls and is released when
+    /// the outermost scope is disposed. Access to the instance is not synchronized.
+    /// </remarks>
+    public class SharedObjectSavingScope: IDisposable
+    {
+        /// <summary>
+        /// Treats two weak references as equal when both targets are still alive and the
+        /// targets are equal according to <see cref="object.Equals(object)"/>.
+        /// </summary>
+        private class WeakReferenceEqualityComparer: IEqualityComparer<WeakReference<object>>
+        {
+            public bool Equals(WeakReference<object> x, WeakReference<object> y)
+            {
+                if(!x.TryGetTarget(out var tx))
+                {
+                    return false;
+                }
+                if(!y.TryGetTarget(out var ty))
+                {
+                    return false;
+                }
+                return tx.Equals(ty);
+            }
+            public int GetHashCode(WeakReference<object> obj)
+            {
+                // A collected target has no hash of its own; 0 keeps the call total.
+                if (!obj.TryGetTarget(out var w))
+                {
+                    return 0;
+                }
+                return w.GetHashCode();
+            }
+        }
+        private static SharedObjectSavingScope? _instance = null;
+        // Keys are weak references so the scope does not keep serialized objects alive.
+        // The comparer makes lookups go by target instead of by WeakReference identity.
+        private readonly Dictionary<WeakReference<object>, int> _shared_object_ids =
+            new Dictionary<WeakReference<object>, int>(new WeakReferenceEqualityComparer());
+        private int _currentId = 0;
+        /// <summary>
+        /// record how many times the scope is nested.
+        /// </summary>
+        private int _nestedDepth = 0;
+        private SharedObjectSavingScope()
+        {
+
+        }
+
+        /// <summary>
+        /// Enters the scope, creating the shared instance on the first (outermost) call.
+        /// Every call must be balanced by one <see cref="Dispose"/>.
+        /// </summary>
+        /// <returns>The shared scope instance.</returns>
+        public static SharedObjectSavingScope Enter()
+        {
+            _instance ??= new SharedObjectSavingScope();
+            _instance._nestedDepth++;
+            return _instance;
+        }
+
+        /// <summary>
+        /// Returns the currently active scope, or null when no scope has been entered.
+        /// </summary>
+        public static SharedObjectSavingScope GetScope()
+        {
+            return _instance;
+        }
+
+        /// <summary>
+        /// Returns the id associated with <paramref name="obj"/>, allocating the next free
+        /// id the first time the object is seen.
+        /// </summary>
+        /// <param name="obj">
+        /// The object being serialized. A null object is never remembered, so each call
+        /// with null consumes and returns a fresh id.
+        /// </param>
+        /// <returns>The id stored for the object; repeated calls return the same value.</returns>
+        public int GetId(object? obj)
+        {
+            if(obj is null)
+            {
+                return _currentId++;
+            }
+            var key = new WeakReference<object>(obj);
+            if (_shared_object_ids.TryGetValue(key, out var known_id))
+            {
+                return known_id;
+            }
+            // Return exactly the id that is stored. Returning the counter after the
+            // increment would report an id that belongs to the next object.
+            var new_id = _currentId++;
+            _shared_object_ids[key] = new_id;
+            return new_id;
+        }
+
+        /// <summary>
+        /// Leaves one nesting level and releases the shared instance once the outermost
+        /// level has been left.
+        /// </summary>
+        public void Dispose()
+        {
+            _nestedDepth--;
+            if(_nestedDepth== 0)
+            {
+                _instance = null;
+            }
+        }
+    }
+
+    public static class serialize_utils
+    {
+        public static readonly string SHARED_OBJECT_KEY = "shared_object_id";
+        ///
+        /// Returns the serialization of the class with the given config.
+        ///
+        ///
+        ///
+        ///
+        ///
+        ///
+        public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? 
shared_object_id = null) + { + JObject res = new JObject(); + res["class_name"] = class_name; + res["config"] = config; + + if(shared_object_id is not null) + { + res[SHARED_OBJECT_KEY] = shared_object_id!; + } + + var scope = SharedObjectSavingScope.GetScope(); + if(scope is not null && obj is not null) + { + res[SHARED_OBJECT_KEY] = scope.GetId(obj); + } + + return res; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 1e6ce409..d845f3ca 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils } /// - /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. + /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (corresponding to `backend.unique_object_name` in Python.) /// /// /// diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 68903eb2..730a33e3 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -14,10 +14,14 @@ limitations under the License. 
******************************************************************************/ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Utils @@ -32,13 +36,21 @@ namespace Tensorflow.Keras.Utils public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); + Debug.Assert(config is LayerArgs); return new LayerConfig { - Config = config, + Config = config as LayerArgs, ClassName = instance.GetType().Name }; } + public static JObject serialize_keras_object(IKerasConfigable instance) + { + var config = JToken.FromObject(instance.get_config()); + // TODO: change the class_name to registered name, instead of system class name. + return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs index 0a1098af..67e8ff79 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs @@ -1,6 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.Keras.Engine; +using System.Diagnostics; using static Tensorflow.KerasApi; +using Tensorflow.Keras.Saving; namespace TensorFlowNET.Keras.UnitTest { @@ -15,7 +17,8 @@ namespace TensorFlowNET.Keras.UnitTest { var model = GetFunctionalModel(); var config = model.get_config(); - var new_model = keras.models.from_config(config); + Debug.Assert(config is ModelConfig); + var new_model = keras.models.from_config(config as ModelConfig); Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); } diff --git 
a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 0f34ff10..90d0a48a 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -15,17 +15,14 @@ using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; +using Tensorflow.Operations; namespace TensorFlowNET.Keras.UnitTest; -// class MNISTLoader -// { -// public MNISTLoader() -// { -// var mnist = new MnistModelLoader() -// -// } -// } +public static class AutoGraphExtension +{ + +} [TestClass] public class SaveTest @@ -42,6 +39,8 @@ public class SaveTest model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); + var g = ops.get_default_graph(); + var data_loader = new MnistModelLoader(); var num_epochs = 1; var batch_size = 50; @@ -50,11 +49,34 @@ public class SaveTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 0, + ValidationSize = 50000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } + + [TestMethod] + public void Temp() + { + var graph = new Graph(); + var g = graph.as_default(); + //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); + var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); + var wrapped_func = tf.autograph.to_graph(func); + var res = wrapped_func(input_tensor); + g.Exit(); + } + + private Tensor func(Tensor tensor) + { + return gen_ops.neg(tensor); + //return array_ops.identity(tensor); + //tf.device("cpu:0"); + //using (ops.control_dependencies(new object[] { res.op })) + //{ + // return array_ops.identity(tensor); + //} + } } \ No newline at end of file