diff --git a/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs b/src/TensorFlowNET.Core/ITensorOrOperation.cs similarity index 76% rename from src/TensorFlowNET.Core/IReturnTensorOrOperation.cs rename to src/TensorFlowNET.Core/ITensorOrOperation.cs index 51c840ac..c29713b9 100644 --- a/src/TensorFlowNET.Core/IReturnTensorOrOperation.cs +++ b/src/TensorFlowNET.Core/ITensorOrOperation.cs @@ -8,7 +8,8 @@ namespace Tensorflow /// in order to limit function return value /// is Tensor or Operation /// - public interface IReturnTensorOrOperation + public interface ITensorOrOperation { + string Device { get; } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 4d268c6d..e6450393 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -7,7 +7,7 @@ using System.Text; namespace Tensorflow { - public partial class Operation : IReturnTensorOrOperation + public partial class Operation : ITensorOrOperation { private readonly IntPtr _handle; // _c_op in python diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 6e58df92..ebe9a0a4 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -7,20 +7,20 @@ namespace Tensorflow { public class control_flow_ops { - public static Operation group(Operation[] inputs, string name = "") + public static Operation group(T[] inputs, string name = "") where T : ITensorOrOperation { return Python.with(new ops.name_scope(name, "group_deps", inputs), scope => { name = scope; // Sorts *inputs according to their devices. - var ops_on_device = new Dictionary>(); + var ops_on_device = new Dictionary>(); foreach (var inp in inputs) { if (ops_on_device.ContainsKey(inp.Device)) ops_on_device[inp.Device].Add(inp); else - ops_on_device[inp.Device] = new List { inp }; + ops_on_device[inp.Device] = new List { inp }; } // 1-level tree. The root node is the returned NoOp node. @@ -28,12 +28,15 @@ namespace Tensorflow { var dev = ops_on_device.Keys.First(); var deps = ops_on_device.Values.First(); - return _GroupControlDeps(dev, deps.ToArray(), name); + if (typeof(T).Name == "Operation") + return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); + else + throw new NotImplementedException("control_flow_ops.group"); } // 2-level tree. The root node is the returned NoOp node. // deps contains 1 NoOp node for each device. - return null; + throw new NotImplementedException("control_flow_ops.group"); }); } diff --git a/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs index ce57e834..6c33a80e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_io_ops.py.cs @@ -14,5 +14,12 @@ namespace Tensorflow return _op; } + + public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "") + { + var _op = _op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); + + return _op.outputs; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index ab7d3304..2f5b4542 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public partial class Tensor : IDisposable, IReturnTensorOrOperation + public partial class Tensor : IDisposable, ITensorOrOperation { private readonly IntPtr _handle; @@ -175,6 +175,9 @@ namespace Tensorflow } public Operation[] Consumers => consumers(); + + public string Device => op.Device; + public Operation[] consumers() { var output = _as_tf_output(); diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index edbb8010..ccd963d6 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -42,7 +42,18 @@ namespace Tensorflow public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) { - throw new NotImplementedException(); + var names = new List(); + var slices = new List(); + var dtypes = new List(); + foreach (var saveable in saveables) + foreach (var spec in saveable.specs) + { + names.Add(spec.name); + slices.Add(spec.slice_spec); + dtypes.Add(spec.dtype); + } + + return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); } public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, @@ -83,6 +94,9 @@ namespace Tensorflow { if (build_save) _AddSaveOps(filename_tensor, saveables); + + if (build_restore) + _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); } }); @@ -94,5 +108,42 @@ namespace Tensorflow var save = save_op(filename_tensor, saveables); return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); } + + /// + /// Add operations to restore saveables. + /// + /// + /// + /// + /// + /// + /// + /// An Operation that restores the variables. + public Operation _AddRestoreOps(Tensor filename_tensor, + SaveableObject[] saveables, + bool restore_sequentially, + bool reshape, + int preferred_shard = -1, + string name = "restore_all") + { + var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially); + var assign_ops = new List(); + int idx = 0; + + foreach(var saveable in saveables) + { + List shapes = null; + if (reshape) + { + throw new NotImplementedException("_AddRestoreOps"); + } + + var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); + idx += saveable.specs.Length; + assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray())); + } + + return control_flow_ops.group(assign_ops.ToArray(), name: name); + } } } diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs index 79be269b..e381cf14 100644 --- a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs @@ -27,5 +27,13 @@ namespace Tensorflow this.specs = specs; this.name = name; } + + public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null) + { + var restored_tensor = restored_tensors[0]; + return gen_state_ops.assign(op, + restored_tensor, + validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined()); + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 39a8d909..b908f657 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -186,7 +186,7 @@ namespace Tensorflow /// the assignment has completed. /// public T assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) - where T : IReturnTensorOrOperation + where T : ITensorOrOperation { var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); if (read_value)