Browse Source

Add many Saveable classes.

tags/v0.8.0
haiping008 6 years ago
parent
commit
e0f1ac0415
20 changed files with 992 additions and 13 deletions
  1. +8
    -1
      docs/source/Train.md
  2. +21
    -0
      src/TensorFlowNET.Core/IPyClass.cs
  3. +42
    -6
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  4. +13
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +18
    -0
      src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs
  6. +401
    -0
      src/TensorFlowNET.Core/Protobuf/Saver.cs
  7. +15
    -1
      src/TensorFlowNET.Core/Python.cs
  8. +98
    -0
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  9. +14
    -0
      src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs
  10. +24
    -0
      src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
  11. +19
    -0
      src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs
  12. +32
    -0
      src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs
  13. +31
    -0
      src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
  14. +113
    -0
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  15. +102
    -0
      src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
  16. +3
    -4
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  17. +20
    -0
      src/TensorFlowNET.Core/Variables/variables.py.cs
  18. +5
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  19. +4
    -0
      src/TensorFlowNET.Core/ops.py.cs
  20. +9
    -1
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 8
- 1
docs/source/Train.md View File

@@ -2,4 +2,11 @@

### Saver

The `tf.train.saver` class provides methods to save and restore models.
The `tf.train.saver` class provides methods to save and restore models.



### Saver Builder

##### Bulk Saver Builder


+ 21
- 0
src/TensorFlowNET.Core/IPyClass.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IPyClass
{
/// <summary>
/// Called when the instance is created.
/// </summary>
/// <param name="args"></param>
void __init__(IPyClass self, dynamic args);

void __enter__(IPyClass self);

void __exit__(IPyClass self);

void __del__(IPyClass self);
}
}

+ 42
- 6
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -116,6 +116,10 @@ namespace Tensorflow
values = new Tensor[] { keywords[input_name] as Tensor };
}

inputs.AddRange(values as Tensor[]);
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
input_types.AddRange(base_types);

if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
if (attrs.ContainsKey(input_arg.NumberAttr))
@@ -144,10 +148,32 @@ namespace Tensorflow
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr);
}
}
else if (!string.IsNullOrEmpty(input_arg.TypeAttr))
{
var attr_value = base_types[0];
if (attrs.ContainsKey(input_arg.TypeAttr))
{

inputs.AddRange(values as Tensor[]);
base_types.AddRange((values as Tensor[]).Select(x => x.dtype.as_base_dtype()));
input_types.AddRange(base_types);
}
else
{
attrs[input_arg.TypeAttr] = attr_value;
inferred_from[input_arg.TypeAttr] = input_name;
}
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
var attr_value = base_types;
if (attrs.ContainsKey(input_arg.TypeListAttr))
{

}
else
{
attrs[input_arg.TypeListAttr] = attr_value;
inferred_from[input_arg.TypeListAttr] = input_name;
}
}
}

// Process remaining attrs
@@ -213,6 +239,11 @@ namespace Tensorflow
case "type":
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
break;
case "list(type)":
if (attr_value.List == null)
attr_value.List = new AttrValue.Types.ListValue();
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
break;
case "bool":
attr_value.B = (bool)value;
break;
@@ -225,9 +256,14 @@ namespace Tensorflow
throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
break;
case "shape":
attr_value.Shape = value == null ?
attr_def.DefaultValue.Shape :
tensor_util.as_shape((long[])value);
if (value == null && attr_def.DefaultValue != null)
attr_value.Shape = attr_def.DefaultValue.Shape;

if(value is TensorShape val1)
attr_value.Shape = val1.as_proto();
else if(value is long[] val2)
attr_value.Shape = tensor_util.as_shape(val2);

break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");


+ 13
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -106,6 +106,19 @@ namespace Tensorflow
throw new NotImplementedException("where");
}

/// <summary>
/// A placeholder op that passes through `input` when its output is not fed.
/// </summary>
/// <param name="input">The default value to produce when output is not fed.</param>
/// <param name="shape"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor placeholder_with_default<T>(T input, TensorShape shape, string name = "")
{
var _op = _op_def_lib._apply_op_helper("PlaceholderWithDefault", name, new { input, shape, name });
return _op.outputs[0];
}

public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "")
{
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });


+ 18
- 0
src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gen_io_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = "")
{
var _op = _op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });

return _op;
}
}
}

+ 401
- 0
src/TensorFlowNET.Core/Protobuf/Saver.cs View File

@@ -0,0 +1,401 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: saver.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from saver.proto</summary>
public static partial class SaverReflection {

#region Descriptor
/// <summary>File descriptor for saver.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static SaverReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CgtzYXZlci5wcm90bxIKdGVuc29yZmxvdyKeAgoIU2F2ZXJEZWYSHAoUZmls",
"ZW5hbWVfdGVuc29yX25hbWUYASABKAkSGAoQc2F2ZV90ZW5zb3JfbmFtZRgC",
"IAEoCRIXCg9yZXN0b3JlX29wX25hbWUYAyABKAkSEwoLbWF4X3RvX2tlZXAY",
"BCABKAUSDwoHc2hhcmRlZBgFIAEoCBIlCh1rZWVwX2NoZWNrcG9pbnRfZXZl",
"cnlfbl9ob3VycxgGIAEoAhI9Cgd2ZXJzaW9uGAcgASgOMiwudGVuc29yZmxv",
"dy5TYXZlckRlZi5DaGVja3BvaW50Rm9ybWF0VmVyc2lvbiI1ChdDaGVja3Bv",
"aW50Rm9ybWF0VmVyc2lvbhIKCgZMRUdBQ1kQABIGCgJWMRABEgYKAlYyEAJC",
"ZQoTb3JnLnRlbnNvcmZsb3cudXRpbEILU2F2ZXJQcm90b3NQAVo8Z2l0aHVi",
"LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl",
"L3Byb3RvYnVm+AEBYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaverDef), global::Tensorflow.SaverDef.Parser, new[]{ "FilenameTensorName", "SaveTensorName", "RestoreOpName", "MaxToKeep", "Sharded", "KeepCheckpointEveryNHours", "Version" }, null, new[]{ typeof(global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) }, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Protocol buffer representing the configuration of a Saver.
/// </summary>
public sealed partial class SaverDef : pb::IMessage<SaverDef> {
private static readonly pb::MessageParser<SaverDef> _parser = new pb::MessageParser<SaverDef>(() => new SaverDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<SaverDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.SaverReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaverDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaverDef(SaverDef other) : this() {
filenameTensorName_ = other.filenameTensorName_;
saveTensorName_ = other.saveTensorName_;
restoreOpName_ = other.restoreOpName_;
maxToKeep_ = other.maxToKeep_;
sharded_ = other.sharded_;
keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_;
version_ = other.version_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SaverDef Clone() {
return new SaverDef(this);
}

/// <summary>Field number for the "filename_tensor_name" field.</summary>
public const int FilenameTensorNameFieldNumber = 1;
private string filenameTensorName_ = "";
/// <summary>
/// The name of the tensor in which to specify the filename when saving or
/// restoring a model checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string FilenameTensorName {
get { return filenameTensorName_; }
set {
filenameTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "save_tensor_name" field.</summary>
public const int SaveTensorNameFieldNumber = 2;
private string saveTensorName_ = "";
/// <summary>
/// The operation to run when saving a model checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string SaveTensorName {
get { return saveTensorName_; }
set {
saveTensorName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "restore_op_name" field.</summary>
public const int RestoreOpNameFieldNumber = 3;
private string restoreOpName_ = "";
/// <summary>
/// The operation to run when restoring a model checkpoint.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string RestoreOpName {
get { return restoreOpName_; }
set {
restoreOpName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "max_to_keep" field.</summary>
public const int MaxToKeepFieldNumber = 4;
private int maxToKeep_;
/// <summary>
/// Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int MaxToKeep {
get { return maxToKeep_; }
set {
maxToKeep_ = value;
}
}

/// <summary>Field number for the "sharded" field.</summary>
public const int ShardedFieldNumber = 5;
private bool sharded_;
/// <summary>
/// Shard the save files, one per device that has Variable nodes.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Sharded {
get { return sharded_; }
set {
sharded_ = value;
}
}

/// <summary>Field number for the "keep_checkpoint_every_n_hours" field.</summary>
public const int KeepCheckpointEveryNHoursFieldNumber = 6;
private float keepCheckpointEveryNHours_;
/// <summary>
/// How often to keep an additional checkpoint. If not specified, only the last
/// "max_to_keep" checkpoints are kept; if specified, in addition to keeping
/// the last "max_to_keep" checkpoints, an additional checkpoint will be kept
/// for every n hours of training.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public float KeepCheckpointEveryNHours {
get { return keepCheckpointEveryNHours_; }
set {
keepCheckpointEveryNHours_ = value;
}
}

/// <summary>Field number for the "version" field.</summary>
public const int VersionFieldNumber = 7;
private global::Tensorflow.SaverDef.Types.CheckpointFormatVersion version_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.SaverDef.Types.CheckpointFormatVersion Version {
get { return version_; }
set {
version_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as SaverDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(SaverDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (FilenameTensorName != other.FilenameTensorName) return false;
if (SaveTensorName != other.SaveTensorName) return false;
if (RestoreOpName != other.RestoreOpName) return false;
if (MaxToKeep != other.MaxToKeep) return false;
if (Sharded != other.Sharded) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false;
if (Version != other.Version) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (FilenameTensorName.Length != 0) hash ^= FilenameTensorName.GetHashCode();
if (SaveTensorName.Length != 0) hash ^= SaveTensorName.GetHashCode();
if (RestoreOpName.Length != 0) hash ^= RestoreOpName.GetHashCode();
if (MaxToKeep != 0) hash ^= MaxToKeep.GetHashCode();
if (Sharded != false) hash ^= Sharded.GetHashCode();
if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours);
if (Version != 0) hash ^= Version.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (FilenameTensorName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(FilenameTensorName);
}
if (SaveTensorName.Length != 0) {
output.WriteRawTag(18);
output.WriteString(SaveTensorName);
}
if (RestoreOpName.Length != 0) {
output.WriteRawTag(26);
output.WriteString(RestoreOpName);
}
if (MaxToKeep != 0) {
output.WriteRawTag(32);
output.WriteInt32(MaxToKeep);
}
if (Sharded != false) {
output.WriteRawTag(40);
output.WriteBool(Sharded);
}
if (KeepCheckpointEveryNHours != 0F) {
output.WriteRawTag(53);
output.WriteFloat(KeepCheckpointEveryNHours);
}
if (Version != 0) {
output.WriteRawTag(56);
output.WriteEnum((int) Version);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (FilenameTensorName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(FilenameTensorName);
}
if (SaveTensorName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(SaveTensorName);
}
if (RestoreOpName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(RestoreOpName);
}
if (MaxToKeep != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxToKeep);
}
if (Sharded != false) {
size += 1 + 1;
}
if (KeepCheckpointEveryNHours != 0F) {
size += 1 + 4;
}
if (Version != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Version);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(SaverDef other) {
if (other == null) {
return;
}
if (other.FilenameTensorName.Length != 0) {
FilenameTensorName = other.FilenameTensorName;
}
if (other.SaveTensorName.Length != 0) {
SaveTensorName = other.SaveTensorName;
}
if (other.RestoreOpName.Length != 0) {
RestoreOpName = other.RestoreOpName;
}
if (other.MaxToKeep != 0) {
MaxToKeep = other.MaxToKeep;
}
if (other.Sharded != false) {
Sharded = other.Sharded;
}
if (other.KeepCheckpointEveryNHours != 0F) {
KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours;
}
if (other.Version != 0) {
Version = other.Version;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
FilenameTensorName = input.ReadString();
break;
}
case 18: {
SaveTensorName = input.ReadString();
break;
}
case 26: {
RestoreOpName = input.ReadString();
break;
}
case 32: {
MaxToKeep = input.ReadInt32();
break;
}
case 40: {
Sharded = input.ReadBool();
break;
}
case 53: {
KeepCheckpointEveryNHours = input.ReadFloat();
break;
}
case 56: {
version_ = (global::Tensorflow.SaverDef.Types.CheckpointFormatVersion) input.ReadEnum();
break;
}
}
}
}

#region Nested types
/// <summary>Container for nested types declared in the SaverDef message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static partial class Types {
/// <summary>
/// A version number that identifies a different on-disk checkpoint format.
/// Usually, each subclass of BaseSaverBuilder works with a particular
/// version/format. However, it is possible that the same builder may be
/// upgraded to support a newer checkpoint format in the future.
/// </summary>
public enum CheckpointFormatVersion {
/// <summary>
/// Internal legacy format.
/// </summary>
[pbr::OriginalName("LEGACY")] Legacy = 0,
/// <summary>
/// Deprecated format: tf.Saver() which works with tensorflow::table::Table.
/// </summary>
[pbr::OriginalName("V1")] V1 = 1,
/// <summary>
/// Current format: more efficient.
/// </summary>
[pbr::OriginalName("V2")] V2 = 2,
}

}
#endregion

}

#endregion

}

#endregion Designer generated code

+ 15
- 1
src/TensorFlowNET.Core/Python.cs View File

@@ -15,6 +15,15 @@ namespace Tensorflow
Console.WriteLine(obj.ToString());
}

public static T New<T>(object args) where T : IPyClass
{
var instance = Activator.CreateInstance<T>();

instance.__init__(instance, args);

return instance;
}

public static void with(IPython py, Action<IPython> action)
{
try
@@ -63,7 +72,7 @@ namespace Tensorflow
catch (Exception ex)
{
Console.WriteLine(ex.ToString());
throw ex;
return default(TOut);
}
finally
{
@@ -97,4 +106,9 @@ namespace Tensorflow

void __exit__();
}

public class PyObject<T> where T : IPyClass
{
public T Instance { get; set; }
}
}

+ 98
- 0
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -0,0 +1,98 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public class BaseSaverBuilder
{
protected int _write_version;

public BaseSaverBuilder(int write_version = 2)
{
_write_version = write_version;
}

public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
{
var tensor_names = new List<string>();
var tensors = new List<Tensor>();
var tensor_slices = new List<string>();

foreach (var saveable in saveables)
{
foreach(var spec in saveable.specs)
{
tensor_names.Add(spec.name);
tensors.Add(spec.tensor);
tensor_slices.Add(spec.slice_spec);
}
}

if (_write_version == 2)
{
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
}
else
{
throw new NotImplementedException("_write_version v1");
}
}

public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially)
{
throw new NotImplementedException();
}

public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
string filename = "model",
bool build_save = true,
bool build_restore = true)
{
if (!build_save || !build_restore)
throw new ValueError("save and restore operations need to be built together " +
" when eager execution is not enabled.");

var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables);

if (max_to_keep < 0)
max_to_keep = 0;

Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
{
name = scope;

// Add a placeholder string tensor for the filename.
var filename_tensor = gen_array_ops.placeholder_with_default( string.IsNullOrEmpty(filename) ? "model" : filename, shape: new TensorShape(), name: "filename");
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new TensorShape(), name: "Const");
// Keep the name "Const" for backwards compatibility.

// Add the save ops.
if (sharded)
{

}
else
{
if (build_save)
_AddSaveOps(filename_tensor, saveables);
}
});

throw new NotImplementedException("");
}

public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)
{
var save = save_op(filename_tensor, saveables);
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor);
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder
{
public BulkSaverBuilder(int write_version = 2) : base(write_version)
{

}
}
}

+ 24
- 0
src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface ISaverBuilder
{
Operation save_op(Tensor filename_tensor, SaveableObject[] saveables);

Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially);

SaverDef _build_internal(RefVariable[] names_to_saveables,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
string filename = "model",
bool build_save = true,
bool build_restore = true);
}
}

+ 19
- 0
src/TensorFlowNET.Core/Train/Saving/ReferenceVariableSaveable.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class ReferenceVariableSaveable : SaveableObject
{
private SaveSpec _spec;

public ReferenceVariableSaveable(Tensor var, string slice_spec, string name)
{
_spec = new SaveSpec(var, slice_spec, name, dtype: var.dtype);
op = var;
specs = new SaveSpec[] { _spec };
this.name = name;
}
}
}

+ 32
- 0
src/TensorFlowNET.Core/Train/Saving/SaveSpec.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Class used to describe tensor slices that need to be saved.
/// </summary>
public class SaveSpec
{
private Tensor _tensor;
public Tensor tensor => _tensor;

private string _slice_spec;
public string slice_spec => _slice_spec;

private string _name;
public string name => _name;

private TF_DataType _dtype;
public TF_DataType dtype => _dtype;

public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
_tensor = tensor;
_slice_spec = slice_spec;
_name = name;
_dtype = dtype;
}
}
}

+ 31
- 0
src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class SaveableObject
{
public Tensor op;
public SaveSpec[] specs;
public string name;
public string device;

public SaveableObject()
{

}

public SaveableObject(Tensor var, string slice_spec, string name)
{

}

public SaveableObject(Tensor op, SaveSpec[] specs, string name)
{
this.op = op;
this.specs = specs;
this.name = name;
}
}
}

+ 113
- 0
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -0,0 +1,113 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Saves and restores variables.
/// </summary>
public class Saver
{
private RefVariable[] _var_list;
private bool _reshape;
private bool _sharded;
private int _max_to_keep;
private double _keep_checkpoint_every_n_hours;
private string _name;
private bool _restore_sequentially;
private SaverDef _saver_def;
private ISaverBuilder _builder;
private bool _allow_empty;
private bool _is_built;
private int _write_version;
private bool _pad_step_number;
private string _filename;
private bool _is_empty;

public Saver(RefVariable[] var_list = null,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
SaverDef saver_def = null,
ISaverBuilder builder = null,
bool defer_build = false,
bool allow_empty = false,
int write_version = 2,
bool pad_step_number = false,
bool save_relative_paths = false,
string filename = "")
{
_var_list = var_list;
_reshape = reshape;
_sharded = sharded;
_max_to_keep = max_to_keep;
_keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours;
_name = name;
_restore_sequentially = restore_sequentially;
_builder = builder;
_is_built = false;
_allow_empty = allow_empty;
_write_version = write_version;
_pad_step_number = pad_step_number;

if (!defer_build)
build();
}

public void build()
{
_build(_filename, build_save: true, build_restore: true);
}

private void _build(string checkpoint_path, bool build_save, bool build_restore)
{
if (_is_built) return;

_is_built = true;

if (_saver_def == null)
{
if (_builder == null)
_builder = new BulkSaverBuilder(_write_version);

if (_var_list == null)
_var_list = variables._all_saveable_objects();

if (_var_list == null || _var_list.Length == 0)
{
if (_allow_empty)
{
_is_empty = true;
return;
}
else
{
throw new ValueError("No variables to save");
}
}
_is_empty = false;

_saver_def = _builder._build_internal(_var_list,
reshape: _reshape,
sharded: _sharded,
max_to_keep: _max_to_keep,
keep_checkpoint_every_n_hours: _keep_checkpoint_every_n_hours,
name: _name,
restore_sequentially: _restore_sequentially,
filename: checkpoint_path,
build_save: build_save,
build_restore: build_restore);
}
else if (_saver_def != null && !string.IsNullOrEmpty(_name))
{
throw new NotImplementedException("");
}

}
}
}

+ 102
- 0
src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs View File

@@ -0,0 +1,102 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
public class saveable_object_util
{
/// <summary>
/// Returns the variables and names that will be used for a Saver.
/// </summary>
/// <param name="names_to_saveables"></param>
/// <returns></returns>
public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables)
{
var names_to_saveables_dict = op_list_to_dict(names_to_saveables);
var saveables = new List<SaveableObject>();
var seen_ops = new List<Tensor>();

foreach (var item in names_to_saveables_dict)
{
foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
return saveables.ToArray();
}

private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : SaveableObject
{
if (seen_ops.Contains(saveable.op))
throw new ValueError($"The same saveable will be restored with two names: {saveable.name}");

saveables.Add(saveable);
seen_ops.Add(saveable.op);
}

/// <summary>
/// Create `SaveableObject`s from an operation.
/// </summary>
/// <param name="op"></param>
/// <param name="name"></param>
/// <returns></returns>
public static IEnumerable<SaveableObject> saveable_objects_for_op(Tensor op, string name)
{
if (false)
{

}
else
{
ops.init_scope();
var variable = ops.internal_convert_to_tensor(op, as_ref: true);
if (variable.op.type == "VariableV2")
yield return new ReferenceVariableSaveable(variable, "", name);
}
}

public static Dictionary<string, Tensor> op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true)
{
op_list = op_list.OrderBy(x => x.name).ToArray();
var names_to_saveables = new Dictionary<string, Tensor>();

foreach(var var in op_list)
{
if (false)
{
throw new NotImplementedException("op_list_to_dict");
}
else
{
if(false) // eager
{

}
else
{
string name = "";
Tensor tensor = null;

if (convert_variable_to_tensor)
{
tensor = ops.internal_convert_to_tensor(var, as_ref: true);
}

if (var.op.type == "ReadVariableOp")
name = var.op.inputs[0].op.Name;
else
name = var.op.Name;

if (names_to_saveables.ContainsKey(name))
throw new ValueError($"At least two variables have the same name: {name}");

names_to_saveables[name] = tensor;
}
}
}

return names_to_saveables;
}
}
}

+ 3
- 4
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -8,10 +8,9 @@ namespace Tensorflow
{
public static class train
{
public static Optimizer GradientDescentOptimizer(double learning_rate)
{
return new GradientDescentOptimizer(learning_rate);
}
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);

public static Saver Saver() => new Saver();
}
}
}

+ 20
- 0
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -16,6 +16,26 @@ namespace Tensorflow
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
}

/// <summary>
/// Returns all variables and `SaveableObject`s that must be checkpointed.
/// </summary>
/// <param name="scope"></param>
/// <returns></returns>
public static RefVariable[] _all_saveable_objects(string scope = "")
{
var all = new List<RefVariable>();

var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
if(collection != null)
all.AddRange(collection as List<RefVariable>);

collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
if (collection != null)
all.AddRange(collection as List<RefVariable>);

return all.ToArray();
}

/// <summary>
/// Returns global variables.
/// </summary>


+ 5
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -27,6 +27,11 @@ namespace Tensorflow
/// Default collection for all variables, except local ones.
/// </summary>
public static string GLOBAL_VARIABLES = "variables";

/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>
public static string SAVEABLE_OBJECTS = "saveable_objects";
}
}
}

+ 4
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -387,6 +387,10 @@ namespace Tensorflow
{
case "Tensor":
return value as Tensor;
case "String":
return constant_op.constant(Convert.ToString(value), name);
case "String[]":
return constant_op.constant(value as string[], name);
case "Int32":
return constant_op.constant(Convert.ToInt32(value), name);
case "Double":


+ 9
- 1
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TrainSaverTest
public class TrainSaverTest : Python
{
[TestMethod]
public void Save()
@@ -20,6 +20,14 @@ namespace TensorFlowNET.UnitTest

// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();

// Add ops to save and restore all the variables.
var saver = tf.train.Saver();

with<Session>(tf.Session(), sess =>
{
sess.run(init_op);
});
}
}
}

Loading…
Cancel
Save