Browse Source

restore SaveableObject.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
459abbcb34
8 changed files with 83 additions and 10 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/ITensorOrOperation.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +8
    -5
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +7
    -0
      src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs
  5. +4
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +52
    -1
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  7. +8
    -0
      src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.cs

src/TensorFlowNET.Core/IReturnTensorOrOperation.cs → src/TensorFlowNET.Core/ITensorOrOperation.cs View File

@@ -8,7 +8,8 @@ namespace Tensorflow
/// in order to limit function return value /// in order to limit function return value
/// is Tensor or Operation /// is Tensor or Operation
/// </summary> /// </summary>
public interface IReturnTensorOrOperation
public interface ITensorOrOperation
{ {
string Device { get; }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -7,7 +7,7 @@ using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
public partial class Operation : IReturnTensorOrOperation
public partial class Operation : ITensorOrOperation
{ {
private readonly IntPtr _handle; // _c_op in python private readonly IntPtr _handle; // _c_op in python




+ 8
- 5
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -7,20 +7,20 @@ namespace Tensorflow
{ {
public class control_flow_ops public class control_flow_ops
{ {
public static Operation group(Operation[] inputs, string name = "")
public static Operation group<T>(T[] inputs, string name = "") where T : ITensorOrOperation
{ {
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope => return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope =>
{ {
name = scope; name = scope;


// Sorts *inputs according to their devices. // Sorts *inputs according to their devices.
var ops_on_device = new Dictionary<string, List<Operation>>();
var ops_on_device = new Dictionary<string, List<T>>();
foreach (var inp in inputs) foreach (var inp in inputs)
{ {
if (ops_on_device.ContainsKey(inp.Device)) if (ops_on_device.ContainsKey(inp.Device))
ops_on_device[inp.Device].Add(inp); ops_on_device[inp.Device].Add(inp);
else else
ops_on_device[inp.Device] = new List<Operation> { inp };
ops_on_device[inp.Device] = new List<T> { inp };
} }


// 1-level tree. The root node is the returned NoOp node. // 1-level tree. The root node is the returned NoOp node.
@@ -28,12 +28,15 @@ namespace Tensorflow
{ {
var dev = ops_on_device.Keys.First(); var dev = ops_on_device.Keys.First();
var deps = ops_on_device.Values.First(); var deps = ops_on_device.Values.First();
return _GroupControlDeps(dev, deps.ToArray(), name);
if (typeof(T).Name == "Operation")
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name);
else
throw new NotImplementedException("control_flow_ops.group");
} }


// 2-level tree. The root node is the returned NoOp node. // 2-level tree. The root node is the returned NoOp node.
// deps contains 1 NoOp node for each device. // deps contains 1 NoOp node for each device.
return null;
throw new NotImplementedException("control_flow_ops.group");
}); });
} }




+ 7
- 0
src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs View File

@@ -14,5 +14,12 @@ namespace Tensorflow


return _op; return _op;
} }

public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "")
{
var _op = _op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });

return _op.outputs;
}
} }
} }

+ 4
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary> /// </summary>
public partial class Tensor : IDisposable, IReturnTensorOrOperation
public partial class Tensor : IDisposable, ITensorOrOperation
{ {
private readonly IntPtr _handle; private readonly IntPtr _handle;


@@ -175,6 +175,9 @@ namespace Tensorflow
} }


public Operation[] Consumers => consumers(); public Operation[] Consumers => consumers();

public string Device => op.Device;

public Operation[] consumers() public Operation[] consumers()
{ {
var output = _as_tf_output(); var output = _as_tf_output();


+ 52
- 1
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -42,7 +42,18 @@ namespace Tensorflow


public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially)
{ {
throw new NotImplementedException();
var names = new List<string>();
var slices = new List<string>();
var dtypes = new List<TF_DataType>();
foreach (var saveable in saveables)
foreach (var spec in saveable.specs)
{
names.Add(spec.name);
slices.Add(spec.slice_spec);
dtypes.Add(spec.dtype);
}

return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
} }


public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
@@ -83,6 +94,9 @@ namespace Tensorflow
{ {
if (build_save) if (build_save)
_AddSaveOps(filename_tensor, saveables); _AddSaveOps(filename_tensor, saveables);

if (build_restore)
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape);
} }
}); });


@@ -94,5 +108,42 @@ namespace Tensorflow
var save = save_op(filename_tensor, saveables); var save = save_op(filename_tensor, saveables);
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor);
} }

/// <summary>
/// Add operations to restore saveables.
/// </summary>
/// <param name="filename_tensor"></param>
/// <param name="saveables"></param>
/// <param name="restore_sequentially"></param>
/// <param name="reshape"></param>
/// <param name="preferred_shard"></param>
/// <param name="name"></param>
/// <returns>An Operation that restores the variables.</returns>
public Operation _AddRestoreOps(Tensor filename_tensor,
SaveableObject[] saveables,
bool restore_sequentially,
bool reshape,
int preferred_shard = -1,
string name = "restore_all")
{
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially);
var assign_ops = new List<Tensor>();
int idx = 0;

foreach(var saveable in saveables)
{
List<TensorShape> shapes = null;
if (reshape)
{
throw new NotImplementedException("_AddRestoreOps");
}

var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length);
idx += saveable.specs.Length;
assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()));
}

return control_flow_ops.group(assign_ops.ToArray(), name: name);
}
} }
} }

+ 8
- 0
src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs View File

@@ -27,5 +27,13 @@ namespace Tensorflow
this.specs = specs; this.specs = specs;
this.name = name; this.name = name;
} }

public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];
return gen_state_ops.assign(op,
restored_tensor,
validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined());
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -186,7 +186,7 @@ namespace Tensorflow
/// the assignment has completed. /// the assignment has completed.
/// </returns> /// </returns>
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
where T : IReturnTensorOrOperation
where T : ITensorOrOperation
{ {
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
if (read_value) if (read_value)


Loading…
Cancel
Save