From e0f1ac04159d599f9b8fe60096562de410435659 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 8 Feb 2019 17:13:37 -0600 Subject: [PATCH] Add many Saveable classes. --- docs/source/Train.md | 9 +- src/TensorFlowNET.Core/IPyClass.cs | 21 + .../Operations/OpDefLibrary.cs | 48 ++- .../Operations/gen_array_ops.cs | 13 + .../Operations/gen_io_ops.py.cs | 18 + src/TensorFlowNET.Core/Protobuf/Saver.cs | 401 ++++++++++++++++++ src/TensorFlowNET.Core/Python.cs | 16 +- .../Train/Saving/BaseSaverBuilder.cs | 98 +++++ .../Train/Saving/BulkSaverBuilder.cs | 14 + .../Train/Saving/ISaverBuilder.cs | 24 ++ .../Train/Saving/ReferenceVariableSaveable.cs | 19 + .../Train/Saving/SaveSpec.cs | 32 ++ .../Train/Saving/SaveableObject.cs | 31 ++ src/TensorFlowNET.Core/Train/Saving/Saver.cs | 113 +++++ .../Train/Saving/saveable_object_util.py.cs | 102 +++++ src/TensorFlowNET.Core/Train/tf.optimizers.cs | 7 +- .../Variables/variables.py.cs | 20 + src/TensorFlowNET.Core/ops.GraphKeys.cs | 5 + src/TensorFlowNET.Core/ops.py.cs | 4 + test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 10 +- 20 files changed, 992 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/IPyClass.cs create mode 100644 src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs create mode 100644 src/TensorFlowNET.Core/Protobuf/Saver.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/Saver.cs create mode 100644 src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs diff --git a/docs/source/Train.md b/docs/source/Train.md index c71b31c7..85d441ba 100644 --- a/docs/source/Train.md +++ b/docs/source/Train.md @@ -2,4 +2,11 @@ ### Saver -The `tf.train.saver` class provides methods to save and restore models. \ No newline at end of file +The `tf.train.saver` class provides methods to save and restore models. + + + +### Saver Builder + +##### Bulk Saver Builder + diff --git a/src/TensorFlowNET.Core/IPyClass.cs b/src/TensorFlowNET.Core/IPyClass.cs new file mode 100644 index 00000000..fd08ab82 --- /dev/null +++ b/src/TensorFlowNET.Core/IPyClass.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public interface IPyClass + { + /// + /// Called when the instance is created. + /// + /// + void __init__(IPyClass self, dynamic args); + + void __enter__(IPyClass self); + + void __exit__(IPyClass self); + + void __del__(IPyClass self); + } +} diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 24b39239..406dac9d 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -116,6 +116,10 @@ namespace Tensorflow values = new Tensor[] { keywords[input_name] as Tensor }; } + inputs.AddRange(values as Tensor[]); + base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); + input_types.AddRange(base_types); + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { if (attrs.ContainsKey(input_arg.NumberAttr)) @@ -144,10 +148,32 @@ namespace Tensorflow var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); } } + else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) + { + var attr_value = base_types[0]; + if (attrs.ContainsKey(input_arg.TypeAttr)) + { - inputs.AddRange(values as Tensor[]); - base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); - input_types.AddRange(base_types); + } + else + { + attrs[input_arg.TypeAttr] = attr_value; + inferred_from[input_arg.TypeAttr] = input_name; + } + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + var attr_value = base_types; + if (attrs.ContainsKey(input_arg.TypeListAttr)) + { + + } + else + { + attrs[input_arg.TypeListAttr] = attr_value; + inferred_from[input_arg.TypeListAttr] = input_name; + } + } } // Process remaining attrs @@ -213,6 +239,11 @@ namespace Tensorflow case "type": attr_value.Type = _MakeType((TF_DataType)value, attr_def); break; + case "list(type)": + if (attr_value.List == null) + attr_value.List = new AttrValue.Types.ListValue(); + attr_value.List.Type.AddRange((value as IList).Select(x => _MakeType(x, attr_def))); + break; case "bool": attr_value.B = (bool)value; break; @@ -225,9 +256,14 @@ namespace Tensorflow throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); break; case "shape": - attr_value.Shape = value == null ? - attr_def.DefaultValue.Shape : - tensor_util.as_shape((long[])value); + if (value == null && attr_def.DefaultValue != null) + attr_value.Shape = attr_def.DefaultValue.Shape; + + if(value is TensorShape val1) + attr_value.Shape = val1.as_proto(); + else if(value is long[] val2) + attr_value.Shape = tensor_util.as_shape(val2); + break; default: throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index efc8e4e0..b3a3e607 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -106,6 +106,19 @@ namespace Tensorflow throw new NotImplementedException("where"); } + /// + /// A placeholder op that passes through `input` when its output is not fed. + /// + /// The default value to produce when output is not fed. + /// + /// + /// + public static Tensor placeholder_with_default(T input, TensorShape shape, string name = "") + { + var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); + return _op.outputs[0]; + } + public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "") { var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs new file mode 100644 index 00000000..ce57e834 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class gen_io_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "") + { + var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); + + return _op; + } + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/Saver.cs b/src/TensorFlowNET.Core/Protobuf/Saver.cs new file mode 100644 index 00000000..e031f2f6 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Saver.cs @@ -0,0 +1,401 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: saver.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from saver.proto + public static partial class SaverReflection { + + #region Descriptor + /// File descriptor for saver.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SaverReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls", + "ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC", + "IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY", + "BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl", + "cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv", + "dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv", + "aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC", + "ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi", + "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl", + "L3Byb3RvYnVm+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the configuration of a Saver. + /// + public sealed partial class SaverDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SaverDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaverDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaverDef(SaverDef other) : this() { + filenameTensorName_ = other.filenameTensorName_; + saveTensorName_ = other.saveTensorName_; + restoreOpName_ = other.restoreOpName_; + maxToKeep_ = other.maxToKeep_; + sharded_ = other.sharded_; + keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_; + version_ = other.version_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SaverDef Clone() { + return new SaverDef(this); + } + + /// Field number for the "filename_tensor_name" field. + public const int FilenameTensorNameFieldNumber = 1; + private string filenameTensorName_ = ""; + /// + /// The name of the tensor in which to specify the filename when saving or + /// restoring a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string FilenameTensorName { + get { return filenameTensorName_; } + set { + filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "save_tensor_name" field. + public const int SaveTensorNameFieldNumber = 2; + private string saveTensorName_ = ""; + /// + /// The operation to run when saving a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string SaveTensorName { + get { return saveTensorName_; } + set { + saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "restore_op_name" field. + public const int RestoreOpNameFieldNumber = 3; + private string restoreOpName_ = ""; + /// + /// The operation to run when restoring a model checkpoint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string RestoreOpName { + get { return restoreOpName_; } + set { + restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "max_to_keep" field. + public const int MaxToKeepFieldNumber = 4; + private int maxToKeep_; + /// + /// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxToKeep { + get { return maxToKeep_; } + set { + maxToKeep_ = value; + } + } + + /// Field number for the "sharded" field. + public const int ShardedFieldNumber = 5; + private bool sharded_; + /// + /// Shard the save files, one per device that has Variable nodes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Sharded { + get { return sharded_; } + set { + sharded_ = value; + } + } + + /// Field number for the "keep_checkpoint_every_n_hours" field. + public const int KeepCheckpointEveryNHoursFieldNumber = 6; + private float keepCheckpointEveryNHours_; + /// + /// How often to keep an additional checkpoint. If not specified, only the last + /// "max_to_keep" checkpoints are kept; if specified, in addition to keeping + /// the last "max_to_keep" checkpoints, an additional checkpoint will be kept + /// for every n hours of training. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float KeepCheckpointEveryNHours { + get { return keepCheckpointEveryNHours_; } + set { + keepCheckpointEveryNHours_ = value; + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 7; + private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version { + get { return version_; } + set { + version_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SaverDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SaverDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FilenameTensorName != other.FilenameTensorName) return false; + if (SaveTensorName != other.SaveTensorName) return false; + if (RestoreOpName != other.RestoreOpName) return false; + if (MaxToKeep != other.MaxToKeep) return false; + if (Sharded != other.Sharded) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false; + if (Version != other.Version) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode(); + if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode(); + if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode(); + if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode(); + if (Sharded != false) hash ^= Sharded.GetHashCode(); + if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours); + if (Version != 0) hash ^= Version.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (FilenameTensorName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FilenameTensorName); + } + if (SaveTensorName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(SaveTensorName); + } + if (RestoreOpName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(RestoreOpName); + } + if (MaxToKeep != 0) { + output.WriteRawTag(32); + output.WriteInt32(MaxToKeep); + } + if (Sharded != false) { + output.WriteRawTag(40); + output.WriteBool(Sharded); + } + if (KeepCheckpointEveryNHours != 0F) { + output.WriteRawTag(53); + output.WriteFloat(KeepCheckpointEveryNHours); + } + if (Version != 0) { + output.WriteRawTag(56); + output.WriteEnum((int) Version); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (FilenameTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName); + } + if (SaveTensorName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName); + } + if (RestoreOpName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName); + } + if (MaxToKeep != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep); + } + if (Sharded != false) { + size += 1 + 1; + } + if (KeepCheckpointEveryNHours != 0F) { + size += 1 + 4; + } + if (Version != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SaverDef other) { + if (other == null) { + return; + } + if (other.FilenameTensorName.Length != 0) { + FilenameTensorName = other.FilenameTensorName; + } + if (other.SaveTensorName.Length != 0) { + SaveTensorName = other.SaveTensorName; + } + if (other.RestoreOpName.Length != 0) { + RestoreOpName = other.RestoreOpName; + } + if (other.MaxToKeep != 0) { + MaxToKeep = other.MaxToKeep; + } + if (other.Sharded != false) { + Sharded = other.Sharded; + } + if (other.KeepCheckpointEveryNHours != 0F) { + KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours; + } + if (other.Version != 0) { + Version = other.Version; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FilenameTensorName = input.ReadString(); + break; + } + case 18: { + SaveTensorName = input.ReadString(); + break; + } + case 26: { + RestoreOpName = input.ReadString(); + break; + } + case 32: { + MaxToKeep = input.ReadInt32(); + break; + } + case 40: { + Sharded = input.ReadBool(); + break; + } + case 53: { + KeepCheckpointEveryNHours = input.ReadFloat(); + break; + } + case 56: { + version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the SaverDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// A version number that identifies a different on-disk checkpoint format. + /// Usually, each subclass of BaseSaverBuilder works with a particular + /// version/format. However, it is possible that the same builder may be + /// upgraded to support a newer checkpoint format in the future. + /// + public enum CheckpointFormatVersion { + /// + /// Internal legacy format. + /// + [pbr::OriginalName("LEGACY")] Legacy = 0, + /// + /// Deprecated format: tf.Saver() which works with tensorflow::table::Table. + /// + [pbr::OriginalName("V1")] V1 = 1, + /// + /// Current format: more efficient. + /// + [pbr::OriginalName("V2")] V2 = 2, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index d351b91e..a47ac262 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -15,6 +15,15 @@ namespace Tensorflow Console.WriteLine(obj.ToString()); } + public static T New(object args) where T : IPyClass + { + var instance = Activator.CreateInstance(); + + instance.__init__(instance, args); + + return instance; + } + public static void with(IPython py, Action action) { try @@ -63,7 +72,7 @@ namespace Tensorflow catch (Exception ex) { Console.WriteLine(ex.ToString()); - throw ex; + return default(TOut); } finally { @@ -97,4 +106,9 @@ namespace Tensorflow void __exit__(); } + + public class PyObject where T : IPyClass + { + public T Instance { get; set; } + } } diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs new file mode 100644 index 00000000..edbb8010 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -0,0 +1,98 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class BaseSaverBuilder + { + protected int _write_version; + + public BaseSaverBuilder(int write_version = 2) + { + _write_version = write_version; + } + + public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) + { + var tensor_names = new List(); + var tensors = new List(); + var tensor_slices = new List(); + + foreach (var saveable in saveables) + { + foreach(var spec in saveable.specs) + { + tensor_names.Add(spec.name); + tensors.Add(spec.tensor); + tensor_slices.Add(spec.slice_spec); + } + } + + if (_write_version == 2) + { + return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); + } + else + { + throw new NotImplementedException("_write_version v1"); + } + } + + public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) + { + throw new NotImplementedException(); + } + + public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + double keep_checkpoint_every_n_hours = 10000, + string name = "", + bool restore_sequentially = false, + string filename = "model", + bool build_save = true, + bool build_restore = true) + { + if (!build_save || !build_restore) + throw new ValueError("save and restore operations need to be built together " + + " when eager execution is not enabled."); + + var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables); + + if (max_to_keep < 0) + max_to_keep = 0; + + Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => + { + name = scope; + + // Add a placeholder string tensor for the filename. + var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); + filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); + // Keep the name "Const" for backwards compatibility. + + // Add the save ops. + if (sharded) + { + + } + else + { + if (build_save) + _AddSaveOps(filename_tensor, saveables); + } + }); + + throw new NotImplementedException(""); + } + + public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) + { + var save = save_op(filename_tensor, saveables); + return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs new file mode 100644 index 00000000..b99b75f0 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder + { + public BulkSaverBuilder(int write_version = 2) : base(write_version) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs new file mode 100644 index 00000000..ed69919e --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public interface ISaverBuilder + { + Operation save_op(Tensor filename_tensor, SaveableObject[] saveables); + + Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially); + + SaverDef _build_internal(RefVariable[] names_to_saveables, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + double keep_checkpoint_every_n_hours = 10000, + string name = "", + bool restore_sequentially = false, + string filename = "model", + bool build_save = true, + bool build_restore = true); + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs b/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs new file mode 100644 index 00000000..583ef889 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class ReferenceVariableSaveable : SaveableObject + { + private SaveSpec _spec; + + public ReferenceVariableSaveable(Tensor var, string slice_spec, string name) + { + _spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype); + op = var; + specs = new SaveSpec[] { _spec }; + this.name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs new file mode 100644 index 00000000..1e932209 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Class used to describe tensor slices that need to be saved. + /// + public class SaveSpec + { + private Tensor _tensor; + public Tensor tensor => _tensor; + + private string _slice_spec; + public string slice_spec => _slice_spec; + + private string _name; + public string name => _name; + + private TF_DataType _dtype; + public TF_DataType dtype => _dtype; + + public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) + { + _tensor = tensor; + _slice_spec = slice_spec; + _name = name; + _dtype = dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs new file mode 100644 index 00000000..79be269b --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class SaveableObject + { + public Tensor op; + public SaveSpec[] specs; + public string name; + public string device; + + public SaveableObject() + { + + } + + public SaveableObject(Tensor var, string slice_spec, string name) + { + + } + + public SaveableObject(Tensor op, SaveSpec[] specs, string name) + { + this.op = op; + this.specs = specs; + this.name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs new file mode 100644 index 00000000..5e7d6333 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Saves and restores variables. + /// + public class Saver + { + private RefVariable[] _var_list; + private bool _reshape; + private bool _sharded; + private int _max_to_keep; + private double _keep_checkpoint_every_n_hours; + private string _name; + private bool _restore_sequentially; + private SaverDef _saver_def; + private ISaverBuilder _builder; + private bool _allow_empty; + private bool _is_built; + private int _write_version; + private bool _pad_step_number; + private string _filename; + private bool _is_empty; + + public Saver(RefVariable[] var_list = null, + bool reshape = false, + bool sharded = false, + int max_to_keep = 5, + double keep_checkpoint_every_n_hours = 10000, + string name = "", + bool restore_sequentially = false, + SaverDef saver_def = null, + ISaverBuilder builder = null, + bool defer_build = false, + bool allow_empty = false, + int write_version = 2, + bool pad_step_number = false, + bool save_relative_paths = false, + string filename = "") + { + _var_list = var_list; + _reshape = reshape; + _sharded = sharded; + _max_to_keep = max_to_keep; + _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours; + _name = name; + _restore_sequentially = restore_sequentially; + _builder = builder; + _is_built = false; + _allow_empty = allow_empty; + _write_version = write_version; + _pad_step_number = pad_step_number; + + if (!defer_build) + build(); + } + + public void build() + { + _build(_filename, build_save: true, build_restore: true); + } + + private void _build(string checkpoint_path, bool build_save, bool build_restore) + { + if (_is_built) return; + + _is_built = true; + + if (_saver_def == null) + { + if (_builder == null) + _builder = new BulkSaverBuilder(_write_version); + + if (_var_list == null) + _var_list = variables._all_saveable_objects(); + + if (_var_list == null || _var_list.Length == 0) + { + if (_allow_empty) + { + _is_empty = true; + return; + } + else + { + throw new ValueError("No variables to save"); + } + } + _is_empty = false; + + _saver_def = _builder._build_internal(_var_list, + reshape: _reshape, + sharded: _sharded, + max_to_keep: _max_to_keep, + keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours, + name: _name, + restore_sequentially: _restore_sequentially, + filename: checkpoint_path, + build_save: build_save, + build_restore: build_restore); + } + else if (_saver_def != null && !string.IsNullOrEmpty(_name)) + { + throw new NotImplementedException(""); + } + + + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs new file mode 100644 index 00000000..3a244e90 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs @@ -0,0 +1,102 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow +{ + public class saveable_object_util + { + /// + /// Returns the variables and names that will be used for a Saver. + /// + /// + /// + public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables) + { + var names_to_saveables_dict = op_list_to_dict(names_to_saveables); + var saveables = new List(); + var seen_ops = new List(); + + foreach (var item in names_to_saveables_dict) + { + foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + return saveables.ToArray(); + } + + private static void _add_saveable(List saveables, List seen_ops, T saveable) where T : SaveableObject + { + if (seen_ops.Contains(saveable.op)) + throw new ValueError($"The same saveable will be restored with two names: {saveable.name}"); + + saveables.Add(saveable); + seen_ops.Add(saveable.op); + } + + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Tensor op, string name) + { + if (false) + { + + } + else + { + ops.init_scope(); + var variable = ops.internal_convert_to_tensor(op, as_ref: true); + if (variable.op.type == "VariableV2") + yield return new ReferenceVariableSaveable(variable, "", name); + } + } + + public static Dictionary op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true) + { + op_list = op_list.OrderBy(x => x.name).ToArray(); + var names_to_saveables = new Dictionary(); + + foreach(var var in op_list) + { + if (false) + { + throw new NotImplementedException("op_list_to_dict"); + } + else + { + if(false) // eager + { + + } + else + { + string name = ""; + Tensor tensor = null; + + if (convert_variable_to_tensor) + { + tensor = ops.internal_convert_to_tensor(var, as_ref: true); + } + + if (var.op.type == "ReadVariableOp") + name = var.op.inputs[0].op.Name; + else + name = var.op.Name; + + if (names_to_saveables.ContainsKey(name)) + throw new ValueError($"At least two variables have the same name: {name}"); + + names_to_saveables[name] = tensor; + } + } + } + + return names_to_saveables; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index 00fe846b..ba4fbea8 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -8,10 +8,9 @@ namespace Tensorflow { public static class train { - public static Optimizer GradientDescentOptimizer(double learning_rate) - { - return new GradientDescentOptimizer(learning_rate); - } + public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); + + public static Saver Saver() => new Saver(); } } } diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 9cefa6fd..5cde1359 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -16,6 +16,26 @@ namespace Tensorflow return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); } + /// + /// Returns all variables and `SaveableObject`s that must be checkpointed. + /// + /// + /// + public static RefVariable[] _all_saveable_objects(string scope = "") + { + var all = new List(); + + var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + if(collection != null) + all.AddRange(collection as List); + + collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); + if (collection != null) + all.AddRange(collection as List); + + return all.ToArray(); + } + /// /// Returns global variables. /// diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index cfc74aff..78e25bd8 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -27,6 +27,11 @@ namespace Tensorflow /// Default collection for all variables, except local ones. /// public static string GLOBAL_VARIABLES = "variables"; + + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public static string SAVEABLE_OBJECTS = "saveable_objects"; } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 74c8a5f7..1fa09224 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -387,6 +387,10 @@ namespace Tensorflow { case "Tensor": return value as Tensor; + case "String": + return constant_op.constant(Convert.ToString(value), name); + case "String[]": + return constant_op.constant(value as string[], name); case "Int32": return constant_op.constant(Convert.ToInt32(value), name); case "Double": diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 99e7ee20..40775fcb 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -7,7 +7,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] - public class TrainSaverTest + public class TrainSaverTest : Python { [TestMethod] public void Save() @@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest // Add an op to initialize the variables. var init_op = tf.global_variables_initializer(); + + // Add ops to save and restore all the variables. + var saver = tf.train.Saver(); + + with(tf.Session(), sess => + { + sess.run(init_op); + }); } } }