@@ -2,4 +2,11 @@ | |||
### Saver | |||
The `tf.train.saver` class provides methods to save and restore models. | |||
The `tf.train.saver` class provides methods to save and restore models. | |||
### Saver Builder | |||
##### Bulk Saver Builder | |||
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public interface IPyClass | |||
{ | |||
/// <summary> | |||
/// Called when the instance is created. | |||
/// </summary> | |||
/// <param name="args"></param> | |||
void __init__(IPyClass self, dynamic args); | |||
void __enter__(IPyClass self); | |||
void __exit__(IPyClass self); | |||
void __del__(IPyClass self); | |||
} | |||
} |
@@ -116,6 +116,10 @@ namespace Tensorflow | |||
values = new Tensor[] { keywords[input_name] as Tensor }; | |||
} | |||
inputs.AddRange(values as Tensor[]); | |||
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); | |||
input_types.AddRange(base_types); | |||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||
{ | |||
if (attrs.ContainsKey(input_arg.NumberAttr)) | |||
@@ -144,10 +148,32 @@ namespace Tensorflow | |||
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); | |||
} | |||
} | |||
else if (!string.IsNullOrEmpty(input_arg.TypeAttr)) | |||
{ | |||
var attr_value = base_types[0]; | |||
if (attrs.ContainsKey(input_arg.TypeAttr)) | |||
{ | |||
inputs.AddRange(values as Tensor[]); | |||
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype())); | |||
input_types.AddRange(base_types); | |||
} | |||
else | |||
{ | |||
attrs[input_arg.TypeAttr] = attr_value; | |||
inferred_from[input_arg.TypeAttr] = input_name; | |||
} | |||
} | |||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||
{ | |||
var attr_value = base_types; | |||
if (attrs.ContainsKey(input_arg.TypeListAttr)) | |||
{ | |||
} | |||
else | |||
{ | |||
attrs[input_arg.TypeListAttr] = attr_value; | |||
inferred_from[input_arg.TypeListAttr] = input_name; | |||
} | |||
} | |||
} | |||
// Process remaining attrs | |||
@@ -213,6 +239,11 @@ namespace Tensorflow | |||
case "type": | |||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
break; | |||
case "list(type)": | |||
if (attr_value.List == null) | |||
attr_value.List = new AttrValue.Types.ListValue(); | |||
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | |||
break; | |||
case "bool": | |||
attr_value.B = (bool)value; | |||
break; | |||
@@ -225,9 +256,14 @@ namespace Tensorflow | |||
throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); | |||
break; | |||
case "shape": | |||
attr_value.Shape = value == null ? | |||
attr_def.DefaultValue.Shape : | |||
tensor_util.as_shape((long[])value); | |||
if (value == null && attr_def.DefaultValue != null) | |||
attr_value.Shape = attr_def.DefaultValue.Shape; | |||
if(value is TensorShape val1) | |||
attr_value.Shape = val1.as_proto(); | |||
else if(value is long[] val2) | |||
attr_value.Shape = tensor_util.as_shape(val2); | |||
break; | |||
default: | |||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
@@ -106,6 +106,19 @@ namespace Tensorflow | |||
throw new NotImplementedException("where"); | |||
} | |||
/// <summary> | |||
/// A placeholder op that passes through `input` when its output is not fed. | |||
/// </summary> | |||
/// <param name="input">The default value to produce when output is not fed.</param> | |||
/// <param name="shape"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "") | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "") | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); | |||
@@ -0,0 +1,18 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class gen_io_ops | |||
{ | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "") | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | |||
return _op; | |||
} | |||
} | |||
} |
@@ -0,0 +1,401 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: saver.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from saver.proto</summary> | |||
public static partial class SaverReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for saver.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static SaverReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls", | |||
"ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC", | |||
"IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY", | |||
"BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl", | |||
"cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv", | |||
"dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv", | |||
"aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC", | |||
"ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi", | |||
"LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl", | |||
"L3Byb3RvYnVm+AEBYgZwcm90bzM=")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { }, | |||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// Protocol buffer representing the configuration of a Saver. | |||
/// </summary> | |||
public sealed partial class SaverDef : pb::IMessage<SaverDef> { | |||
private static readonly pb::MessageParser<SaverDef> _parser = new pb::MessageParser<SaverDef>(() => new SaverDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<SaverDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SaverDef() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SaverDef(SaverDef other) : this() { | |||
filenameTensorName_ = other.filenameTensorName_; | |||
saveTensorName_ = other.saveTensorName_; | |||
restoreOpName_ = other.restoreOpName_; | |||
maxToKeep_ = other.maxToKeep_; | |||
sharded_ = other.sharded_; | |||
keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_; | |||
version_ = other.version_; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public SaverDef Clone() { | |||
return new SaverDef(this); | |||
} | |||
/// <summary>Field number for the "filename_tensor_name" field.</summary> | |||
public const int FilenameTensorNameFieldNumber = 1; | |||
private string filenameTensorName_ = ""; | |||
/// <summary> | |||
/// The name of the tensor in which to specify the filename when saving or | |||
/// restoring a model checkpoint. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public string FilenameTensorName { | |||
get { return filenameTensorName_; } | |||
set { | |||
filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "save_tensor_name" field.</summary> | |||
public const int SaveTensorNameFieldNumber = 2; | |||
private string saveTensorName_ = ""; | |||
/// <summary> | |||
/// The operation to run when saving a model checkpoint. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public string SaveTensorName { | |||
get { return saveTensorName_; } | |||
set { | |||
saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "restore_op_name" field.</summary> | |||
public const int RestoreOpNameFieldNumber = 3; | |||
private string restoreOpName_ = ""; | |||
/// <summary> | |||
/// The operation to run when restoring a model checkpoint. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public string RestoreOpName { | |||
get { return restoreOpName_; } | |||
set { | |||
restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "max_to_keep" field.</summary> | |||
public const int MaxToKeepFieldNumber = 4; | |||
private int maxToKeep_; | |||
/// <summary> | |||
/// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int MaxToKeep { | |||
get { return maxToKeep_; } | |||
set { | |||
maxToKeep_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "sharded" field.</summary> | |||
public const int ShardedFieldNumber = 5; | |||
private bool sharded_; | |||
/// <summary> | |||
/// Shard the save files, one per device that has Variable nodes. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Sharded { | |||
get { return sharded_; } | |||
set { | |||
sharded_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "keep_checkpoint_every_n_hours" field.</summary> | |||
public const int KeepCheckpointEveryNHoursFieldNumber = 6; | |||
private float keepCheckpointEveryNHours_; | |||
/// <summary> | |||
/// How often to keep an additional checkpoint. If not specified, only the last | |||
/// "max_to_keep" checkpoints are kept; if specified, in addition to keeping | |||
/// the last "max_to_keep" checkpoints, an additional checkpoint will be kept | |||
/// for every n hours of training. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public float KeepCheckpointEveryNHours { | |||
get { return keepCheckpointEveryNHours_; } | |||
set { | |||
keepCheckpointEveryNHours_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "version" field.</summary> | |||
public const int VersionFieldNumber = 7; | |||
private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version { | |||
get { return version_; } | |||
set { | |||
version_ = value; | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as SaverDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(SaverDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (FilenameTensorName != other.FilenameTensorName) return false; | |||
if (SaveTensorName != other.SaveTensorName) return false; | |||
if (RestoreOpName != other.RestoreOpName) return false; | |||
if (MaxToKeep != other.MaxToKeep) return false; | |||
if (Sharded != other.Sharded) return false; | |||
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false; | |||
if (Version != other.Version) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode(); | |||
if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode(); | |||
if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode(); | |||
if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode(); | |||
if (Sharded != false) hash ^= Sharded.GetHashCode(); | |||
if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours); | |||
if (Version != 0) hash ^= Version.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
if (FilenameTensorName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(FilenameTensorName); | |||
} | |||
if (SaveTensorName.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(SaveTensorName); | |||
} | |||
if (RestoreOpName.Length != 0) { | |||
output.WriteRawTag(26); | |||
output.WriteString(RestoreOpName); | |||
} | |||
if (MaxToKeep != 0) { | |||
output.WriteRawTag(32); | |||
output.WriteInt32(MaxToKeep); | |||
} | |||
if (Sharded != false) { | |||
output.WriteRawTag(40); | |||
output.WriteBool(Sharded); | |||
} | |||
if (KeepCheckpointEveryNHours != 0F) { | |||
output.WriteRawTag(53); | |||
output.WriteFloat(KeepCheckpointEveryNHours); | |||
} | |||
if (Version != 0) { | |||
output.WriteRawTag(56); | |||
output.WriteEnum((int) Version); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (FilenameTensorName.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName); | |||
} | |||
if (SaveTensorName.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName); | |||
} | |||
if (RestoreOpName.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName); | |||
} | |||
if (MaxToKeep != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep); | |||
} | |||
if (Sharded != false) { | |||
size += 1 + 1; | |||
} | |||
if (KeepCheckpointEveryNHours != 0F) { | |||
size += 1 + 4; | |||
} | |||
if (Version != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(SaverDef other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.FilenameTensorName.Length != 0) { | |||
FilenameTensorName = other.FilenameTensorName; | |||
} | |||
if (other.SaveTensorName.Length != 0) { | |||
SaveTensorName = other.SaveTensorName; | |||
} | |||
if (other.RestoreOpName.Length != 0) { | |||
RestoreOpName = other.RestoreOpName; | |||
} | |||
if (other.MaxToKeep != 0) { | |||
MaxToKeep = other.MaxToKeep; | |||
} | |||
if (other.Sharded != false) { | |||
Sharded = other.Sharded; | |||
} | |||
if (other.KeepCheckpointEveryNHours != 0F) { | |||
KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours; | |||
} | |||
if (other.Version != 0) { | |||
Version = other.Version; | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
FilenameTensorName = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
SaveTensorName = input.ReadString(); | |||
break; | |||
} | |||
case 26: { | |||
RestoreOpName = input.ReadString(); | |||
break; | |||
} | |||
case 32: { | |||
MaxToKeep = input.ReadInt32(); | |||
break; | |||
} | |||
case 40: { | |||
Sharded = input.ReadBool(); | |||
break; | |||
} | |||
case 53: { | |||
KeepCheckpointEveryNHours = input.ReadFloat(); | |||
break; | |||
} | |||
case 56: { | |||
version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
#region Nested types | |||
/// <summary>Container for nested types declared in the SaverDef message type.</summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static partial class Types { | |||
/// <summary> | |||
/// A version number that identifies a different on-disk checkpoint format. | |||
/// Usually, each subclass of BaseSaverBuilder works with a particular | |||
/// version/format. However, it is possible that the same builder may be | |||
/// upgraded to support a newer checkpoint format in the future. | |||
/// </summary> | |||
public enum CheckpointFormatVersion { | |||
/// <summary> | |||
/// Internal legacy format. | |||
/// </summary> | |||
[pbr::OriginalName("LEGACY")] Legacy = 0, | |||
/// <summary> | |||
/// Deprecated format: tf.Saver() which works with tensorflow::table::Table. | |||
/// </summary> | |||
[pbr::OriginalName("V1")] V1 = 1, | |||
/// <summary> | |||
/// Current format: more efficient. | |||
/// </summary> | |||
[pbr::OriginalName("V2")] V2 = 2, | |||
} | |||
} | |||
#endregion | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
@@ -15,6 +15,15 @@ namespace Tensorflow | |||
Console.WriteLine(obj.ToString()); | |||
} | |||
public static T New<T>(object args) where T : IPyClass | |||
{ | |||
var instance = Activator.CreateInstance<T>(); | |||
instance.__init__(instance, args); | |||
return instance; | |||
} | |||
public static void with(IPython py, Action<IPython> action) | |||
{ | |||
try | |||
@@ -63,7 +72,7 @@ namespace Tensorflow | |||
catch (Exception ex) | |||
{ | |||
Console.WriteLine(ex.ToString()); | |||
throw ex; | |||
return default(TOut); | |||
} | |||
finally | |||
{ | |||
@@ -97,4 +106,9 @@ namespace Tensorflow | |||
void __exit__(); | |||
} | |||
public class PyObject<T> where T : IPyClass | |||
{ | |||
public T Instance { get; set; } | |||
} | |||
} |
@@ -0,0 +1,98 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class BaseSaverBuilder | |||
{ | |||
protected int _write_version; | |||
public BaseSaverBuilder(int write_version = 2) | |||
{ | |||
_write_version = write_version; | |||
} | |||
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | |||
{ | |||
var tensor_names = new List<string>(); | |||
var tensors = new List<Tensor>(); | |||
var tensor_slices = new List<string>(); | |||
foreach (var saveable in saveables) | |||
{ | |||
foreach(var spec in saveable.specs) | |||
{ | |||
tensor_names.Add(spec.name); | |||
tensors.Add(spec.tensor); | |||
tensor_slices.Add(spec.slice_spec); | |||
} | |||
} | |||
if (_write_version == 2) | |||
{ | |||
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("_write_version v1"); | |||
} | |||
} | |||
public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
string filename = "model", | |||
bool build_save = true, | |||
bool build_restore = true) | |||
{ | |||
if (!build_save || !build_restore) | |||
throw new ValueError("save and restore operations need to be built together " + | |||
" when eager execution is not enabled."); | |||
var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables); | |||
if (max_to_keep < 0) | |||
max_to_keep = 0; | |||
Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
{ | |||
name = scope; | |||
// Add a placeholder string tensor for the filename. | |||
var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename"); | |||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const"); | |||
// Keep the name "Const" for backwards compatibility. | |||
// Add the save ops. | |||
if (sharded) | |||
{ | |||
} | |||
else | |||
{ | |||
if (build_save) | |||
_AddSaveOps(filename_tensor, saveables); | |||
} | |||
}); | |||
throw new NotImplementedException(""); | |||
} | |||
public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
{ | |||
var save = save_op(filename_tensor, saveables); | |||
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | |||
{ | |||
public BulkSaverBuilder(int write_version = 2) : base(write_version) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,24 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public interface ISaverBuilder | |||
{ | |||
Operation save_op(Tensor filename_tensor, SaveableObject[] saveables); | |||
Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially); | |||
SaverDef _build_internal(RefVariable[] names_to_saveables, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
string filename = "model", | |||
bool build_save = true, | |||
bool build_restore = true); | |||
} | |||
} |
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class ReferenceVariableSaveable : SaveableObject | |||
{ | |||
private SaveSpec _spec; | |||
public ReferenceVariableSaveable(Tensor var, string slice_spec, string name) | |||
{ | |||
_spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype); | |||
op = var; | |||
specs = new SaveSpec[] { _spec }; | |||
this.name = name; | |||
} | |||
} | |||
} |
@@ -0,0 +1,32 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Class used to describe tensor slices that need to be saved. | |||
/// </summary> | |||
public class SaveSpec | |||
{ | |||
private Tensor _tensor; | |||
public Tensor tensor => _tensor; | |||
private string _slice_spec; | |||
public string slice_spec => _slice_spec; | |||
private string _name; | |||
public string name => _name; | |||
private TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
_tensor = tensor; | |||
_slice_spec = slice_spec; | |||
_name = name; | |||
_dtype = dtype; | |||
} | |||
} | |||
} |
@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class SaveableObject | |||
{ | |||
public Tensor op; | |||
public SaveSpec[] specs; | |||
public string name; | |||
public string device; | |||
public SaveableObject() | |||
{ | |||
} | |||
public SaveableObject(Tensor var, string slice_spec, string name) | |||
{ | |||
} | |||
public SaveableObject(Tensor op, SaveSpec[] specs, string name) | |||
{ | |||
this.op = op; | |||
this.specs = specs; | |||
this.name = name; | |||
} | |||
} | |||
} |
@@ -0,0 +1,113 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Saves and restores variables. | |||
/// </summary> | |||
public class Saver | |||
{ | |||
private RefVariable[] _var_list; | |||
private bool _reshape; | |||
private bool _sharded; | |||
private int _max_to_keep; | |||
private double _keep_checkpoint_every_n_hours; | |||
private string _name; | |||
private bool _restore_sequentially; | |||
private SaverDef _saver_def; | |||
private ISaverBuilder _builder; | |||
private bool _allow_empty; | |||
private bool _is_built; | |||
private int _write_version; | |||
private bool _pad_step_number; | |||
private string _filename; | |||
private bool _is_empty; | |||
public Saver(RefVariable[] var_list = null, | |||
bool reshape = false, | |||
bool sharded = false, | |||
int max_to_keep = 5, | |||
double keep_checkpoint_every_n_hours = 10000, | |||
string name = "", | |||
bool restore_sequentially = false, | |||
SaverDef saver_def = null, | |||
ISaverBuilder builder = null, | |||
bool defer_build = false, | |||
bool allow_empty = false, | |||
int write_version = 2, | |||
bool pad_step_number = false, | |||
bool save_relative_paths = false, | |||
string filename = "") | |||
{ | |||
_var_list = var_list; | |||
_reshape = reshape; | |||
_sharded = sharded; | |||
_max_to_keep = max_to_keep; | |||
_keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours; | |||
_name = name; | |||
_restore_sequentially = restore_sequentially; | |||
_builder = builder; | |||
_is_built = false; | |||
_allow_empty = allow_empty; | |||
_write_version = write_version; | |||
_pad_step_number = pad_step_number; | |||
if (!defer_build) | |||
build(); | |||
} | |||
public void build() | |||
{ | |||
_build(_filename, build_save: true, build_restore: true); | |||
} | |||
private void _build(string checkpoint_path, bool build_save, bool build_restore) | |||
{ | |||
if (_is_built) return; | |||
_is_built = true; | |||
if (_saver_def == null) | |||
{ | |||
if (_builder == null) | |||
_builder = new BulkSaverBuilder(_write_version); | |||
if (_var_list == null) | |||
_var_list = variables._all_saveable_objects(); | |||
if (_var_list == null || _var_list.Length == 0) | |||
{ | |||
if (_allow_empty) | |||
{ | |||
_is_empty = true; | |||
return; | |||
} | |||
else | |||
{ | |||
throw new ValueError("No variables to save"); | |||
} | |||
} | |||
_is_empty = false; | |||
_saver_def = _builder._build_internal(_var_list, | |||
reshape: _reshape, | |||
sharded: _sharded, | |||
max_to_keep: _max_to_keep, | |||
keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours, | |||
name: _name, | |||
restore_sequentially: _restore_sequentially, | |||
filename: checkpoint_path, | |||
build_save: build_save, | |||
build_restore: build_restore); | |||
} | |||
else if (_saver_def != null && !string.IsNullOrEmpty(_name)) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,102 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class saveable_object_util | |||
{ | |||
/// <summary> | |||
/// Returns the variables and names that will be used for a Saver. | |||
/// </summary> | |||
/// <param name="names_to_saveables"></param> | |||
/// <returns></returns> | |||
public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables) | |||
{ | |||
var names_to_saveables_dict = op_list_to_dict(names_to_saveables); | |||
var saveables = new List<SaveableObject>(); | |||
var seen_ops = new List<Tensor>(); | |||
foreach (var item in names_to_saveables_dict) | |||
{ | |||
foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) | |||
_add_saveable(saveables, seen_ops, converted_saveable_object); | |||
} | |||
return saveables.ToArray(); | |||
} | |||
private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : SaveableObject | |||
{ | |||
if (seen_ops.Contains(saveable.op)) | |||
throw new ValueError($"The same saveable will be restored with two names: {saveable.name}"); | |||
saveables.Add(saveable); | |||
seen_ops.Add(saveable.op); | |||
} | |||
/// <summary> | |||
/// Create `SaveableObject`s from an operation. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static IEnumerable<SaveableObject> saveable_objects_for_op(Tensor op, string name) | |||
{ | |||
if (false) | |||
{ | |||
} | |||
else | |||
{ | |||
ops.init_scope(); | |||
var variable = ops.internal_convert_to_tensor(op, as_ref: true); | |||
if (variable.op.type == "VariableV2") | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
} | |||
} | |||
public static Dictionary<string, Tensor> op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true) | |||
{ | |||
op_list = op_list.OrderBy(x => x.name).ToArray(); | |||
var names_to_saveables = new Dictionary<string, Tensor>(); | |||
foreach(var var in op_list) | |||
{ | |||
if (false) | |||
{ | |||
throw new NotImplementedException("op_list_to_dict"); | |||
} | |||
else | |||
{ | |||
if(false) // eager | |||
{ | |||
} | |||
else | |||
{ | |||
string name = ""; | |||
Tensor tensor = null; | |||
if (convert_variable_to_tensor) | |||
{ | |||
tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
} | |||
if (var.op.type == "ReadVariableOp") | |||
name = var.op.inputs[0].op.Name; | |||
else | |||
name = var.op.Name; | |||
if (names_to_saveables.ContainsKey(name)) | |||
throw new ValueError($"At least two variables have the same name: {name}"); | |||
names_to_saveables[name] = tensor; | |||
} | |||
} | |||
} | |||
return names_to_saveables; | |||
} | |||
} | |||
} |
@@ -8,10 +8,9 @@ namespace Tensorflow | |||
{ | |||
public static class train | |||
{ | |||
public static Optimizer GradientDescentOptimizer(double learning_rate) | |||
{ | |||
return new GradientDescentOptimizer(learning_rate); | |||
} | |||
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | |||
public static Saver Saver() => new Saver(); | |||
} | |||
} | |||
} |
@@ -16,6 +16,26 @@ namespace Tensorflow | |||
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
} | |||
/// <summary> | |||
/// Returns all variables and `SaveableObject`s that must be checkpointed. | |||
/// </summary> | |||
/// <param name="scope"></param> | |||
/// <returns></returns> | |||
public static RefVariable[] _all_saveable_objects(string scope = "") | |||
{ | |||
var all = new List<RefVariable>(); | |||
var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||
if(collection != null) | |||
all.AddRange(collection as List<RefVariable>); | |||
collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope); | |||
if (collection != null) | |||
all.AddRange(collection as List<RefVariable>); | |||
return all.ToArray(); | |||
} | |||
/// <summary> | |||
/// Returns global variables. | |||
/// </summary> | |||
@@ -27,6 +27,11 @@ namespace Tensorflow | |||
/// Default collection for all variables, except local ones. | |||
/// </summary> | |||
public static string GLOBAL_VARIABLES = "variables"; | |||
/// <summary> | |||
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
/// </summary> | |||
public static string SAVEABLE_OBJECTS = "saveable_objects"; | |||
} | |||
} | |||
} |
@@ -387,6 +387,10 @@ namespace Tensorflow | |||
{ | |||
case "Tensor": | |||
return value as Tensor; | |||
case "String": | |||
return constant_op.constant(Convert.ToString(value), name); | |||
case "String[]": | |||
return constant_op.constant(value as string[], name); | |||
case "Int32": | |||
return constant_op.constant(Convert.ToInt32(value), name); | |||
case "Double": | |||
@@ -7,7 +7,7 @@ using Tensorflow; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class TrainSaverTest | |||
public class TrainSaverTest : Python | |||
{ | |||
[TestMethod] | |||
public void Save() | |||
@@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest | |||
// Add an op to initialize the variables. | |||
var init_op = tf.global_variables_initializer(); | |||
// Add ops to save and restore all the variables. | |||
var saver = tf.train.Saver(); | |||
with<Session>(tf.Session(), sess => | |||
{ | |||
sess.run(init_op); | |||
}); | |||
} | |||
} | |||
} |