@@ -2,4 +2,11 @@
### Saver

The `tf.train.Saver` class provides methods to save and restore models.
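A minimal usage sketch in C#, mirroring the unit test added in this PR. It assumes the default graph already contains variables to checkpoint; note that parts of the save path still throw `NotImplementedException` at this stage of the port.

```csharp
using Tensorflow;

public class SaverExample : Python   // inherit Python to get the `with` helper
{
    public void Run()
    {
        // Variables are assumed to have been created in the default graph already.
        var init_op = tf.global_variables_initializer();

        // Constructing the Saver builds the SaveV2 ops for all global variables.
        var saver = tf.train.Saver();

        with<Session>(tf.Session(), sess =>
        {
            sess.run(init_op);
            // Writing a checkpoint to disk is not wired up yet in this PR.
        });
    }
}
```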
### Saver Builder

##### Bulk Saver Builder
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public interface IPyClass
    {
        /// <summary>
        /// Called when the instance is created.
        /// </summary>
        /// <param name="args"></param>
        void __init__(IPyClass self, dynamic args);

        /// <summary>
        /// Called when entering a `with` block.
        /// </summary>
        void __enter__(IPyClass self);

        /// <summary>
        /// Called when leaving a `with` block.
        /// </summary>
        void __exit__(IPyClass self);

        /// <summary>
        /// Called when the instance is destroyed.
        /// </summary>
        void __del__(IPyClass self);
    }
}

@@ -116,6 +116,10 @@ namespace Tensorflow
                        values = new Tensor[] { keywords[input_name] as Tensor };
                    }

                    inputs.AddRange(values as Tensor[]);
                    base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
                    input_types.AddRange(base_types);

                    if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                    {
                        if (attrs.ContainsKey(input_arg.NumberAttr))
@@ -144,10 +148,32 @@ namespace Tensorflow
                            var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr);
                        }
                    }
                    else if (!string.IsNullOrEmpty(input_arg.TypeAttr))
                    {
                        var attr_value = base_types[0];
                        if (attrs.ContainsKey(input_arg.TypeAttr))
                        {
                            inputs.AddRange(values as Tensor[]);
                            base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
                            input_types.AddRange(base_types);
                        }
                        else
                        {
                            attrs[input_arg.TypeAttr] = attr_value;
                            inferred_from[input_arg.TypeAttr] = input_name;
                        }
                    }
                    else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                    {
                        var attr_value = base_types;
                        if (attrs.ContainsKey(input_arg.TypeListAttr))
                        {
                        }
                        else
                        {
                            attrs[input_arg.TypeListAttr] = attr_value;
                            inferred_from[input_arg.TypeListAttr] = input_name;
                        }
                    }
                }

                // Process remaining attrs
@@ -213,6 +239,11 @@ namespace Tensorflow
                    case "type":
                        attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                        break;
                    case "list(type)":
                        if (attr_value.List == null)
                            attr_value.List = new AttrValue.Types.ListValue();
                        attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
                        break;
                    case "bool":
                        attr_value.B = (bool)value;
                        break;
@@ -225,9 +256,14 @@ namespace Tensorflow
                        throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
                        break;
                    case "shape":
                        if (value == null && attr_def.DefaultValue != null)
                            attr_value.Shape = attr_def.DefaultValue.Shape;
                        if (value is TensorShape val1)
                            attr_value.Shape = val1.as_proto();
                        else if (value is long[] val2)
                            attr_value.Shape = tensor_util.as_shape(val2);
                        break;
                    default:
                        throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
@@ -106,6 +106,19 @@ namespace Tensorflow
            throw new NotImplementedException("where");
        }

        /// <summary>
        /// A placeholder op that passes through `input` when its output is not fed.
        /// </summary>
        /// <param name="input">The default value to produce when output is not fed.</param>
        /// <param name="shape">The (possibly partial) shape of the tensor.</param>
        /// <param name="name">A name for the operation (optional).</param>
        /// <returns>A `Tensor` with the same type as `input`.</returns>
        public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "")
        {
            var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name });

            return _op.outputs[0];
        }

        public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "")
        {
            var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public class gen_io_ops
    {
        public static OpDefLibrary _op_def_lib = new OpDefLibrary();

        public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "")
        {
            var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });

            return _op;
        }
    }
}

@@ -0,0 +1,401 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: saver.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;

namespace Tensorflow {

  /// <summary>Holder for reflection information generated from saver.proto</summary>
  public static partial class SaverReflection {

    #region Descriptor
    /// <summary>File descriptor for saver.proto</summary>
    public static pbr::FileDescriptor Descriptor {
      get { return descriptor; }
    }
    private static pbr::FileDescriptor descriptor;

    static SaverReflection() {
      byte[] descriptorData = global::System.Convert.FromBase64String(
          string.Concat(
            "CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls",
            "ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC",
            "IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY",
            "BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl",
            "cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv",
            "dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv",
            "aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC",
            "ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi",
            "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl",
            "L3Byb3RvYnVm+AEBYgZwcm90bzM="));
      descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
          new pbr::FileDescriptor[] { },
          new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
            new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null)
          }));
    }
    #endregion

  }

  #region Messages
  /// <summary>
  /// Protocol buffer representing the configuration of a Saver.
  /// </summary>
  public sealed partial class SaverDef : pb::IMessage<SaverDef> {
    private static readonly pb::MessageParser<SaverDef> _parser = new pb::MessageParser<SaverDef>(() => new SaverDef());
    private pb::UnknownFieldSet _unknownFields;

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public static pb::MessageParser<SaverDef> Parser { get { return _parser; } }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public static pbr::MessageDescriptor Descriptor {
      get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; }
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    pbr::MessageDescriptor pb::IMessage.Descriptor {
      get { return Descriptor; }
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public SaverDef() {
      OnConstruction();
    }

    partial void OnConstruction();

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public SaverDef(SaverDef other) : this() {
      filenameTensorName_ = other.filenameTensorName_;
      saveTensorName_ = other.saveTensorName_;
      restoreOpName_ = other.restoreOpName_;
      maxToKeep_ = other.maxToKeep_;
      sharded_ = other.sharded_;
      keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_;
      version_ = other.version_;
      _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public SaverDef Clone() {
      return new SaverDef(this);
    }

    /// <summary>Field number for the "filename_tensor_name" field.</summary>
    public const int FilenameTensorNameFieldNumber = 1;
    private string filenameTensorName_ = "";
    /// <summary>
    /// The name of the tensor in which to specify the filename when saving or
    /// restoring a model checkpoint.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public string FilenameTensorName {
      get { return filenameTensorName_; }
      set {
        filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
      }
    }

    /// <summary>Field number for the "save_tensor_name" field.</summary>
    public const int SaveTensorNameFieldNumber = 2;
    private string saveTensorName_ = "";
    /// <summary>
    /// The operation to run when saving a model checkpoint.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public string SaveTensorName {
      get { return saveTensorName_; }
      set {
        saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
      }
    }

    /// <summary>Field number for the "restore_op_name" field.</summary>
    public const int RestoreOpNameFieldNumber = 3;
    private string restoreOpName_ = "";
    /// <summary>
    /// The operation to run when restoring a model checkpoint.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public string RestoreOpName {
      get { return restoreOpName_; }
      set {
        restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
      }
    }

    /// <summary>Field number for the "max_to_keep" field.</summary>
    public const int MaxToKeepFieldNumber = 4;
    private int maxToKeep_;
    /// <summary>
    /// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public int MaxToKeep {
      get { return maxToKeep_; }
      set {
        maxToKeep_ = value;
      }
    }

    /// <summary>Field number for the "sharded" field.</summary>
    public const int ShardedFieldNumber = 5;
    private bool sharded_;
    /// <summary>
    /// Shard the save files, one per device that has Variable nodes.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public bool Sharded {
      get { return sharded_; }
      set {
        sharded_ = value;
      }
    }

    /// <summary>Field number for the "keep_checkpoint_every_n_hours" field.</summary>
    public const int KeepCheckpointEveryNHoursFieldNumber = 6;
    private float keepCheckpointEveryNHours_;
    /// <summary>
    /// How often to keep an additional checkpoint. If not specified, only the last
    /// "max_to_keep" checkpoints are kept; if specified, in addition to keeping
    /// the last "max_to_keep" checkpoints, an additional checkpoint will be kept
    /// for every n hours of training.
    /// </summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public float KeepCheckpointEveryNHours {
      get { return keepCheckpointEveryNHours_; }
      set {
        keepCheckpointEveryNHours_ = value;
      }
    }

    /// <summary>Field number for the "version" field.</summary>
    public const int VersionFieldNumber = 7;
    private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0;
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version {
      get { return version_; }
      set {
        version_ = value;
      }
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public override bool Equals(object other) {
      return Equals(other as SaverDef);
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public bool Equals(SaverDef other) {
      if (ReferenceEquals(other, null)) {
        return false;
      }
      if (ReferenceEquals(other, this)) {
        return true;
      }
      if (FilenameTensorName != other.FilenameTensorName) return false;
      if (SaveTensorName != other.SaveTensorName) return false;
      if (RestoreOpName != other.RestoreOpName) return false;
      if (MaxToKeep != other.MaxToKeep) return false;
      if (Sharded != other.Sharded) return false;
      if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false;
      if (Version != other.Version) return false;
      return Equals(_unknownFields, other._unknownFields);
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public override int GetHashCode() {
      int hash = 1;
      if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode();
      if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode();
      if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode();
      if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode();
      if (Sharded != false) hash ^= Sharded.GetHashCode();
      if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours);
      if (Version != 0) hash ^= Version.GetHashCode();
      if (_unknownFields != null) {
        hash ^= _unknownFields.GetHashCode();
      }
      return hash;
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public override string ToString() {
      return pb::JsonFormatter.ToDiagnosticString(this);
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public void WriteTo(pb::CodedOutputStream output) {
      if (FilenameTensorName.Length != 0) {
        output.WriteRawTag(10);
        output.WriteString(FilenameTensorName);
      }
      if (SaveTensorName.Length != 0) {
        output.WriteRawTag(18);
        output.WriteString(SaveTensorName);
      }
      if (RestoreOpName.Length != 0) {
        output.WriteRawTag(26);
        output.WriteString(RestoreOpName);
      }
      if (MaxToKeep != 0) {
        output.WriteRawTag(32);
        output.WriteInt32(MaxToKeep);
      }
      if (Sharded != false) {
        output.WriteRawTag(40);
        output.WriteBool(Sharded);
      }
      if (KeepCheckpointEveryNHours != 0F) {
        output.WriteRawTag(53);
        output.WriteFloat(KeepCheckpointEveryNHours);
      }
      if (Version != 0) {
        output.WriteRawTag(56);
        output.WriteEnum((int) Version);
      }
      if (_unknownFields != null) {
        _unknownFields.WriteTo(output);
      }
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public int CalculateSize() {
      int size = 0;
      if (FilenameTensorName.Length != 0) {
        size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName);
      }
      if (SaveTensorName.Length != 0) {
        size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName);
      }
      if (RestoreOpName.Length != 0) {
        size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName);
      }
      if (MaxToKeep != 0) {
        size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep);
      }
      if (Sharded != false) {
        size += 1 + 1;
      }
      if (KeepCheckpointEveryNHours != 0F) {
        size += 1 + 4;
      }
      if (Version != 0) {
        size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version);
      }
      if (_unknownFields != null) {
        size += _unknownFields.CalculateSize();
      }
      return size;
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public void MergeFrom(SaverDef other) {
      if (other == null) {
        return;
      }
      if (other.FilenameTensorName.Length != 0) {
        FilenameTensorName = other.FilenameTensorName;
      }
      if (other.SaveTensorName.Length != 0) {
        SaveTensorName = other.SaveTensorName;
      }
      if (other.RestoreOpName.Length != 0) {
        RestoreOpName = other.RestoreOpName;
      }
      if (other.MaxToKeep != 0) {
        MaxToKeep = other.MaxToKeep;
      }
      if (other.Sharded != false) {
        Sharded = other.Sharded;
      }
      if (other.KeepCheckpointEveryNHours != 0F) {
        KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours;
      }
      if (other.Version != 0) {
        Version = other.Version;
      }
      _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
    }

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public void MergeFrom(pb::CodedInputStream input) {
      uint tag;
      while ((tag = input.ReadTag()) != 0) {
        switch(tag) {
          default:
            _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
            break;
          case 10: {
            FilenameTensorName = input.ReadString();
            break;
          }
          case 18: {
            SaveTensorName = input.ReadString();
            break;
          }
          case 26: {
            RestoreOpName = input.ReadString();
            break;
          }
          case 32: {
            MaxToKeep = input.ReadInt32();
            break;
          }
          case 40: {
            Sharded = input.ReadBool();
            break;
          }
          case 53: {
            KeepCheckpointEveryNHours = input.ReadFloat();
            break;
          }
          case 56: {
            version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum();
            break;
          }
        }
      }
    }

    #region Nested types
    /// <summary>Container for nested types declared in the SaverDef message type.</summary>
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
    public static partial class Types {
      /// <summary>
      /// A version number that identifies a different on-disk checkpoint format.
      /// Usually, each subclass of BaseSaverBuilder works with a particular
      /// version/format. However, it is possible that the same builder may be
      /// upgraded to support a newer checkpoint format in the future.
      /// </summary>
      public enum CheckpointFormatVersion {
        /// <summary>
        /// Internal legacy format.
        /// </summary>
        [pbr::OriginalName("LEGACY")] Legacy = 0,
        /// <summary>
        /// Deprecated format: tf.Saver() which works with tensorflow::table::Table.
        /// </summary>
        [pbr::OriginalName("V1")] V1 = 1,
        /// <summary>
        /// Current format: more efficient.
        /// </summary>
        [pbr::OriginalName("V2")] V2 = 2,
      }
    }
    #endregion

  }
  #endregion

}

#endregion Designer generated code

@@ -15,6 +15,15 @@ namespace Tensorflow
            Console.WriteLine(obj.ToString());
        }

        public static T New<T>(object args) where T : IPyClass
        {
            var instance = Activator.CreateInstance<T>();
            instance.__init__(instance, args);

            return instance;
        }

        public static void with(IPython py, Action<IPython> action)
        {
            try
@@ -63,7 +72,7 @@ namespace Tensorflow
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
                return default(TOut);
            }
            finally
            {
@@ -97,4 +106,9 @@ namespace Tensorflow
        void __exit__();
    }

    public class PyObject<T> where T : IPyClass
    {
        public T Instance { get; set; }
    }
}
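A hedged sketch of how `Python.New<T>` and `IPyClass` are meant to work together. `TrainingScope` and `PyClassDemo` are hypothetical types invented for illustration; the only assumptions are that these helpers live on the `Python` class shown above and that `Activator.CreateInstance<T>` can find a public parameterless constructor.

```csharp
using System;
using Tensorflow;

// Hypothetical type, not part of this PR; it only satisfies IPyClass.
public class TrainingScope : IPyClass
{
    public string Name { get; private set; }

    public void __init__(IPyClass self, dynamic args) => Name = (string)args;
    public void __enter__(IPyClass self) { /* acquire resources for a `with` block */ }
    public void __exit__(IPyClass self) { /* release resources */ }
    public void __del__(IPyClass self) { /* finalization hook */ }
}

public static class PyClassDemo
{
    public static void Run()
    {
        // New<T> constructs the instance and immediately calls __init__ on it.
        var scope = Python.New<TrainingScope>("training");
        Console.WriteLine(scope.Name); // "training"
    }
}
```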
@@ -0,0 +1,98 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
    public class BaseSaverBuilder
    {
        protected int _write_version;

        public BaseSaverBuilder(int write_version = 2)
        {
            _write_version = write_version;
        }

        public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
        {
            var tensor_names = new List<string>();
            var tensors = new List<Tensor>();
            var tensor_slices = new List<string>();

            foreach (var saveable in saveables)
            {
                foreach (var spec in saveable.specs)
                {
                    tensor_names.Add(spec.name);
                    tensors.Add(spec.tensor);
                    tensor_slices.Add(spec.slice_spec);
                }
            }

            if (_write_version == 2)
            {
                return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
            }
            else
            {
                throw new NotImplementedException("_write_version v1");
            }
        }

        public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially)
        {
            throw new NotImplementedException();
        }

        public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
            bool reshape = false,
            bool sharded = false,
            int max_to_keep = 5,
            double keep_checkpoint_every_n_hours = 10000,
            string name = "",
            bool restore_sequentially = false,
            string filename = "model",
            bool build_save = true,
            bool build_restore = true)
        {
            if (!build_save || !build_restore)
                throw new ValueError("save and restore operations need to be built together " +
                    " when eager execution is not enabled.");

            var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables);

            if (max_to_keep < 0)
                max_to_keep = 0;

            Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
            {
                name = scope;

                // Add a placeholder string tensor for the filename.
                var filename_tensor = gen_array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename");
                filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const");
                // Keep the name "Const" for backwards compatibility.

                // Add the save ops.
                if (sharded)
                {
                }
                else
                {
                    if (build_save)
                        _AddSaveOps(filename_tensor, saveables);
                }
            });

            throw new NotImplementedException("");
        }

        public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)
        {
            var save = save_op(filename_tensor, saveables);
            return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor);
        }
    }
}
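To make the relationship between `SaveSpec`, `SaveableObject`, and `save_op` concrete, here is an illustrative sketch that is not part of the PR. `v` and `filename_tensor` are assumed to be existing graph tensors: a `VariableV2` output and a string filename tensor.

```csharp
using Tensorflow;

public static class SaveOpSketch
{
    // Returns the SaveV2 operation that checkpoints a single variable tensor.
    public static Operation BuildSave(Tensor v, Tensor filename_tensor)
    {
        // One SaveSpec per slice; an empty slice_spec means "save the whole tensor".
        var spec = new SaveSpec(v, slice_spec: "", name: "my_var", dtype: v.dtype);
        var saveable = new SaveableObject(v, new SaveSpec[] { spec }, "my_var");

        // write_version == 2 routes to gen_io_ops.save_v2 (the SaveV2 kernel).
        var builder = new BaseSaverBuilder(write_version: 2);
        return builder.save_op(filename_tensor, new SaveableObject[] { saveable });
    }
}
```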
@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder
    {
        public BulkSaverBuilder(int write_version = 2) : base(write_version)
        {
        }
    }
}

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public interface ISaverBuilder
    {
        Operation save_op(Tensor filename_tensor, SaveableObject[] saveables);

        Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially);

        SaverDef _build_internal(RefVariable[] names_to_saveables,
            bool reshape = false,
            bool sharded = false,
            int max_to_keep = 5,
            double keep_checkpoint_every_n_hours = 10000,
            string name = "",
            bool restore_sequentially = false,
            string filename = "model",
            bool build_save = true,
            bool build_restore = true);
    }
}

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public class ReferenceVariableSaveable : SaveableObject
    {
        private SaveSpec _spec;

        public ReferenceVariableSaveable(Tensor var, string slice_spec, string name)
        {
            _spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype);
            op = var;
            specs = new SaveSpec[] { _spec };
            this.name = name;
        }
    }
}

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Class used to describe tensor slices that need to be saved.
    /// </summary>
    public class SaveSpec
    {
        private Tensor _tensor;
        public Tensor tensor => _tensor;

        private string _slice_spec;
        public string slice_spec => _slice_spec;

        private string _name;
        public string name => _name;

        private TF_DataType _dtype;
        public TF_DataType dtype => _dtype;

        public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            _tensor = tensor;
            _slice_spec = slice_spec;
            _name = name;
            _dtype = dtype;
        }
    }
}

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public class SaveableObject
    {
        public Tensor op;
        public SaveSpec[] specs;
        public string name;
        public string device;

        public SaveableObject()
        {
        }

        public SaveableObject(Tensor var, string slice_spec, string name)
        {
        }

        public SaveableObject(Tensor op, SaveSpec[] specs, string name)
        {
            this.op = op;
            this.specs = specs;
            this.name = name;
        }
    }
}

@@ -0,0 +1,113 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Saves and restores variables.
    /// </summary>
    public class Saver
    {
        private RefVariable[] _var_list;
        private bool _reshape;
        private bool _sharded;
        private int _max_to_keep;
        private double _keep_checkpoint_every_n_hours;
        private string _name;
        private bool _restore_sequentially;
        private SaverDef _saver_def;
        private ISaverBuilder _builder;
        private bool _allow_empty;
        private bool _is_built;
        private int _write_version;
        private bool _pad_step_number;
        private string _filename;
        private bool _is_empty;

        public Saver(RefVariable[] var_list = null,
            bool reshape = false,
            bool sharded = false,
            int max_to_keep = 5,
            double keep_checkpoint_every_n_hours = 10000,
            string name = "",
            bool restore_sequentially = false,
            SaverDef saver_def = null,
            ISaverBuilder builder = null,
            bool defer_build = false,
            bool allow_empty = false,
            int write_version = 2,
            bool pad_step_number = false,
            bool save_relative_paths = false,
            string filename = "")
        {
            _var_list = var_list;
            _reshape = reshape;
            _sharded = sharded;
            _max_to_keep = max_to_keep;
            _keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours;
            _name = name;
            _restore_sequentially = restore_sequentially;
            _saver_def = saver_def;
            _builder = builder;
            _is_built = false;
            _allow_empty = allow_empty;
            _write_version = write_version;
            _pad_step_number = pad_step_number;
            _filename = filename;

            if (!defer_build)
                build();
        }
        public void build()
        {
            _build(_filename, build_save: true, build_restore: true);
        }

        private void _build(string checkpoint_path, bool build_save, bool build_restore)
        {
            if (_is_built) return;
            _is_built = true;

            if (_saver_def == null)
            {
                if (_builder == null)
                    _builder = new BulkSaverBuilder(_write_version);

                if (_var_list == null)
                    _var_list = variables._all_saveable_objects();

                if (_var_list == null || _var_list.Length == 0)
                {
                    if (_allow_empty)
                    {
                        _is_empty = true;
                        return;
                    }
                    else
                    {
                        throw new ValueError("No variables to save");
                    }
                }

                _is_empty = false;

                _saver_def = _builder._build_internal(_var_list,
                    reshape: _reshape,
                    sharded: _sharded,
                    max_to_keep: _max_to_keep,
                    keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours,
                    name: _name,
                    restore_sequentially: _restore_sequentially,
                    filename: checkpoint_path,
                    build_save: build_save,
                    build_restore: build_restore);
            }
            else if (_saver_def != null && !string.IsNullOrEmpty(_name))
            {
                throw new NotImplementedException("");
            }
        }
    }
}

@@ -0,0 +1,102 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
    public class saveable_object_util
    {
        /// <summary>
        /// Returns the variables and names that will be used for a Saver.
        /// </summary>
        /// <param name="names_to_saveables"></param>
        /// <returns></returns>
        public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables)
        {
            var names_to_saveables_dict = op_list_to_dict(names_to_saveables);
            var saveables = new List<SaveableObject>();
            var seen_ops = new List<Tensor>();

            foreach (var item in names_to_saveables_dict)
            {
                foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
                    _add_saveable(saveables, seen_ops, converted_saveable_object);
            }

            return saveables.ToArray();
        }

        private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : SaveableObject
        {
            if (seen_ops.Contains(saveable.op))
                throw new ValueError($"The same saveable will be restored with two names: {saveable.name}");

            saveables.Add(saveable);
            seen_ops.Add(saveable.op);
        }

        /// <summary>
        /// Create `SaveableObject`s from an operation.
        /// </summary>
        /// <param name="op"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static IEnumerable<SaveableObject> saveable_objects_for_op(Tensor op, string name)
        {
            if (false)
            {
            }
            else
            {
                ops.init_scope();
                var variable = ops.internal_convert_to_tensor(op, as_ref: true);

                if (variable.op.type == "VariableV2")
                    yield return new ReferenceVariableSaveable(variable, "", name);
            }
        }

        public static Dictionary<string, Tensor> op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true)
        {
            op_list = op_list.OrderBy(x => x.name).ToArray();
            var names_to_saveables = new Dictionary<string, Tensor>();

            foreach (var var in op_list)
            {
                if (false)
                {
                    throw new NotImplementedException("op_list_to_dict");
                }
                else
                {
                    if (false) // eager
                    {
                    }
                    else
                    {
                        string name = "";
                        Tensor tensor = null;

                        if (convert_variable_to_tensor)
                        {
                            tensor = ops.internal_convert_to_tensor(var, as_ref: true);
                        }

                        if (var.op.type == "ReadVariableOp")
                            name = var.op.inputs[0].op.Name;
                        else
                            name = var.op.Name;

                        if (names_to_saveables.ContainsKey(name))
                            throw new ValueError($"At least two variables have the same name: {name}");

                        names_to_saveables[name] = tensor;
                    }
                }
            }

            return names_to_saveables;
        }
    }
}
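For reference, this is the path `Saver._build` takes through these helpers, written out as a standalone sketch that uses only names introduced in this PR; it assumes the default graph already holds global variables.

```csharp
using System;
using Tensorflow;

public static class SaveableInspection
{
    public static void Run()
    {
        // Everything registered in the GLOBAL_VARIABLES / SAVEABLE_OBJECTS collections.
        RefVariable[] vars = variables._all_saveable_objects();

        // One SaveableObject per VariableV2 op, each carrying its SaveSpecs.
        SaveableObject[] saveables = saveable_object_util.validate_and_slice_inputs(vars);

        foreach (var s in saveables)
            Console.WriteLine($"{s.name}: {s.specs.Length} slice spec(s)");
    }
}
```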
@@ -8,10 +8,9 @@ namespace Tensorflow
    {
        public static class train
        {
            public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);

            public static Saver Saver() => new Saver();
        }
    }
}
@@ -16,6 +16,26 @@ namespace Tensorflow
            return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
        }

        /// <summary>
        /// Returns all variables and `SaveableObject`s that must be checkpointed.
        /// </summary>
        /// <param name="scope"></param>
        /// <returns></returns>
        public static RefVariable[] _all_saveable_objects(string scope = "")
        {
            var all = new List<RefVariable>();

            var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
            if (collection != null)
                all.AddRange(collection as List<RefVariable>);

            collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
            if (collection != null)
                all.AddRange(collection as List<RefVariable>);

            return all.ToArray();
        }

        /// <summary>
        /// Returns global variables.
        /// </summary>
@@ -27,6 +27,11 @@ namespace Tensorflow
            /// Default collection for all variables, except local ones.
            /// </summary>
            public static string GLOBAL_VARIABLES = "variables";

            /// <summary>
            /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
            /// </summary>
            public static string SAVEABLE_OBJECTS = "saveable_objects";
        }
    }
}
@@ -387,6 +387,10 @@ namespace Tensorflow
            {
                case "Tensor":
                    return value as Tensor;
                case "String":
                    return constant_op.constant(Convert.ToString(value), name);
                case "String[]":
                    return constant_op.constant(value as string[], name);
                case "Int32":
                    return constant_op.constant(Convert.ToInt32(value), name);
                case "Double":
@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
    [TestClass]
    public class TrainSaverTest : Python
    {
        [TestMethod]
        public void Save()
@@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest
            // Add an op to initialize the variables.
            var init_op = tf.global_variables_initializer();

            // Add ops to save and restore all the variables.
            var saver = tf.train.Saver();

            with<Session>(tf.Session(), sess =>
            {
                sess.run(init_op);
            });
        }
    }
}