Browse Source

restore SaveableObject.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
459abbcb34
8 changed files with 83 additions and 10 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/ITensorOrOperation.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +8
    -5
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +7
    -0
      src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs
  5. +4
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +52
    -1
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  7. +8
    -0
      src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.cs

src/TensorFlowNET.Core/IReturnTensorOrOperation.cs → src/TensorFlowNET.Core/ITensorOrOperation.cs View File

@@ -8,7 +8,8 @@ namespace Tensorflow
/// in order to limit function return value
/// is Tensor or Operation
/// </summary>
public interface IReturnTensorOrOperation
public interface ITensorOrOperation
{
string Device { get; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -7,7 +7,7 @@ using System.Text;

namespace Tensorflow
{
public partial class Operation : IReturnTensorOrOperation
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python



+ 8
- 5
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -7,20 +7,20 @@ namespace Tensorflow
{
public class control_flow_ops
{
public static Operation group(Operation[] inputs, string name = "")
public static Operation group<T>(T[] inputs, string name = "") where T : ITensorOrOperation
{
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope =>
{
name = scope;

// Sorts *inputs according to their devices.
var ops_on_device = new Dictionary<string, List<Operation>>();
var ops_on_device = new Dictionary<string, List<T>>();
foreach (var inp in inputs)
{
if (ops_on_device.ContainsKey(inp.Device))
ops_on_device[inp.Device].Add(inp);
else
ops_on_device[inp.Device] = new List<Operation> { inp };
ops_on_device[inp.Device] = new List<T> { inp };
}

// 1-level tree. The root node is the returned NoOp node.
@@ -28,12 +28,15 @@ namespace Tensorflow
{
var dev = ops_on_device.Keys.First();
var deps = ops_on_device.Values.First();
return _GroupControlDeps(dev, deps.ToArray(), name);
if (typeof(T).Name == "Operation")
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name);
else
throw new NotImplementedException("control_flow_ops.group");
}

// 2-level tree. The root node is the returned NoOp node.
// deps contains 1 NoOp node for each device.
return null;
throw new NotImplementedException("control_flow_ops.group");
});
}



+ 7
- 0
src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs View File

@@ -14,5 +14,12 @@ namespace Tensorflow

return _op;
}

public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "")
{
var _op = _op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });

return _op.outputs;
}
}
}

+ 4
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary>
public partial class Tensor : IDisposable, IReturnTensorOrOperation
public partial class Tensor : IDisposable, ITensorOrOperation
{
private readonly IntPtr _handle;

@@ -175,6 +175,9 @@ namespace Tensorflow
}

public Operation[] Consumers => consumers();

public string Device => op.Device;

public Operation[] consumers()
{
var output = _as_tf_output();


+ 52
- 1
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -42,7 +42,18 @@ namespace Tensorflow

public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially)
{
throw new NotImplementedException();
var names = new List<string>();
var slices = new List<string>();
var dtypes = new List<TF_DataType>();
foreach (var saveable in saveables)
foreach (var spec in saveable.specs)
{
names.Add(spec.name);
slices.Add(spec.slice_spec);
dtypes.Add(spec.dtype);
}

return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
}

public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
@@ -83,6 +94,9 @@ namespace Tensorflow
{
if (build_save)
_AddSaveOps(filename_tensor, saveables);

if (build_restore)
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape);
}
});

@@ -94,5 +108,42 @@ namespace Tensorflow
var save = save_op(filename_tensor, saveables);
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor);
}

/// <summary>
/// Add operations to restore saveables.
/// </summary>
/// <param name="filename_tensor"></param>
/// <param name="saveables"></param>
/// <param name="restore_sequentially"></param>
/// <param name="reshape"></param>
/// <param name="preferred_shard"></param>
/// <param name="name"></param>
/// <returns>An Operation that restores the variables.</returns>
public Operation _AddRestoreOps(Tensor filename_tensor,
SaveableObject[] saveables,
bool restore_sequentially,
bool reshape,
int preferred_shard = -1,
string name = "restore_all")
{
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially);
var assign_ops = new List<Tensor>();
int idx = 0;

foreach(var saveable in saveables)
{
List<TensorShape> shapes = null;
if (reshape)
{
throw new NotImplementedException("_AddRestoreOps");
}

var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length);
idx += saveable.specs.Length;
assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()));
}

return control_flow_ops.group(assign_ops.ToArray(), name: name);
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs View File

@@ -27,5 +27,13 @@ namespace Tensorflow
this.specs = specs;
this.name = name;
}

public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];
return gen_state_ops.assign(op,
restored_tensor,
validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined());
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -186,7 +186,7 @@ namespace Tensorflow
/// the assignment has completed.
/// </returns>
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
where T : IReturnTensorOrOperation
where T : ITensorOrOperation
{
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
if (read_value)


Loading…
Cancel
Save