@@ -8,7 +8,8 @@ namespace Tensorflow | |||||
/// in order to limit function return value | /// in order to limit function return value | ||||
/// is Tensor or Operation | /// is Tensor or Operation | ||||
/// </summary> | /// </summary> | ||||
public interface IReturnTensorOrOperation | |||||
public interface ITensorOrOperation | |||||
{ | { | ||||
string Device { get; } | |||||
} | } | ||||
} | } |
@@ -7,7 +7,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class Operation : IReturnTensorOrOperation | |||||
public partial class Operation : ITensorOrOperation | |||||
{ | { | ||||
private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
@@ -7,20 +7,20 @@ namespace Tensorflow | |||||
{ | { | ||||
public class control_flow_ops | public class control_flow_ops | ||||
{ | { | ||||
public static Operation group(Operation[] inputs, string name = "") | |||||
public static Operation group<T>(T[] inputs, string name = "") where T : ITensorOrOperation | |||||
{ | { | ||||
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope => | return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope => | ||||
{ | { | ||||
name = scope; | name = scope; | ||||
// Sorts *inputs according to their devices. | // Sorts *inputs according to their devices. | ||||
var ops_on_device = new Dictionary<string, List<Operation>>(); | |||||
var ops_on_device = new Dictionary<string, List<T>>(); | |||||
foreach (var inp in inputs) | foreach (var inp in inputs) | ||||
{ | { | ||||
if (ops_on_device.ContainsKey(inp.Device)) | if (ops_on_device.ContainsKey(inp.Device)) | ||||
ops_on_device[inp.Device].Add(inp); | ops_on_device[inp.Device].Add(inp); | ||||
else | else | ||||
ops_on_device[inp.Device] = new List<Operation> { inp }; | |||||
ops_on_device[inp.Device] = new List<T> { inp }; | |||||
} | } | ||||
// 1-level tree. The root node is the returned NoOp node. | // 1-level tree. The root node is the returned NoOp node. | ||||
@@ -28,12 +28,15 @@ namespace Tensorflow | |||||
{ | { | ||||
var dev = ops_on_device.Keys.First(); | var dev = ops_on_device.Keys.First(); | ||||
var deps = ops_on_device.Values.First(); | var deps = ops_on_device.Values.First(); | ||||
return _GroupControlDeps(dev, deps.ToArray(), name); | |||||
if (typeof(T).Name == "Operation") | |||||
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); | |||||
else | |||||
throw new NotImplementedException("control_flow_ops.group"); | |||||
} | } | ||||
// 2-level tree. The root node is the returned NoOp node. | // 2-level tree. The root node is the returned NoOp node. | ||||
// deps contains 1 NoOp node for each device. | // deps contains 1 NoOp node for each device. | ||||
return null; | |||||
throw new NotImplementedException("control_flow_ops.group"); | |||||
}); | }); | ||||
} | } | ||||
@@ -14,5 +14,12 @@ namespace Tensorflow | |||||
return _op; | return _op; | ||||
} | } | ||||
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||||
return _op.outputs; | |||||
} | |||||
} | } | ||||
} | } |
@@ -12,7 +12,7 @@ namespace Tensorflow | |||||
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
/// </summary> | /// </summary> | ||||
public partial class Tensor : IDisposable, IReturnTensorOrOperation | |||||
public partial class Tensor : IDisposable, ITensorOrOperation | |||||
{ | { | ||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
@@ -175,6 +175,9 @@ namespace Tensorflow | |||||
} | } | ||||
public Operation[] Consumers => consumers(); | public Operation[] Consumers => consumers(); | ||||
public string Device => op.Device; | |||||
public Operation[] consumers() | public Operation[] consumers() | ||||
{ | { | ||||
var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
@@ -42,7 +42,18 @@ namespace Tensorflow | |||||
public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | ||||
{ | { | ||||
throw new NotImplementedException(); | |||||
var names = new List<string>(); | |||||
var slices = new List<string>(); | |||||
var dtypes = new List<TF_DataType>(); | |||||
foreach (var saveable in saveables) | |||||
foreach (var spec in saveable.specs) | |||||
{ | |||||
names.Add(spec.name); | |||||
slices.Add(spec.slice_spec); | |||||
dtypes.Add(spec.dtype); | |||||
} | |||||
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
} | } | ||||
public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | ||||
@@ -83,6 +94,9 @@ namespace Tensorflow | |||||
{ | { | ||||
if (build_save) | if (build_save) | ||||
_AddSaveOps(filename_tensor, saveables); | _AddSaveOps(filename_tensor, saveables); | ||||
if (build_restore) | |||||
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||||
} | } | ||||
}); | }); | ||||
@@ -94,5 +108,42 @@ namespace Tensorflow | |||||
var save = save_op(filename_tensor, saveables); | var save = save_op(filename_tensor, saveables); | ||||
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | ||||
} | } | ||||
/// <summary> | |||||
/// Add operations to restore saveables. | |||||
/// </summary> | |||||
/// <param name="filename_tensor"></param> | |||||
/// <param name="saveables"></param> | |||||
/// <param name="restore_sequentially"></param> | |||||
/// <param name="reshape"></param> | |||||
/// <param name="preferred_shard"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns>An Operation that restores the variables.</returns> | |||||
public Operation _AddRestoreOps(Tensor filename_tensor, | |||||
SaveableObject[] saveables, | |||||
bool restore_sequentially, | |||||
bool reshape, | |||||
int preferred_shard = -1, | |||||
string name = "restore_all") | |||||
{ | |||||
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially); | |||||
var assign_ops = new List<Tensor>(); | |||||
int idx = 0; | |||||
foreach(var saveable in saveables) | |||||
{ | |||||
List<TensorShape> shapes = null; | |||||
if (reshape) | |||||
{ | |||||
throw new NotImplementedException("_AddRestoreOps"); | |||||
} | |||||
var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); | |||||
idx += saveable.specs.Length; | |||||
assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray())); | |||||
} | |||||
return control_flow_ops.group(assign_ops.ToArray(), name: name); | |||||
} | |||||
} | } | ||||
} | } |
@@ -27,5 +27,13 @@ namespace Tensorflow | |||||
this.specs = specs; | this.specs = specs; | ||||
this.name = name; | this.name = name; | ||||
} | } | ||||
public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null) | |||||
{ | |||||
var restored_tensor = restored_tensors[0]; | |||||
return gen_state_ops.assign(op, | |||||
restored_tensor, | |||||
validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined()); | |||||
} | |||||
} | } | ||||
} | } |
@@ -186,7 +186,7 @@ namespace Tensorflow | |||||
/// the assignment has completed. | /// the assignment has completed. | ||||
/// </returns> | /// </returns> | ||||
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | ||||
where T : IReturnTensorOrOperation | |||||
where T : ITensorOrOperation | |||||
{ | { | ||||
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | ||||
if (read_value) | if (read_value) | ||||