|
|
@@ -42,7 +42,18 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) |
|
|
|
{ |
|
|
|
throw new NotImplementedException(); |
|
|
|
var names = new List<string>(); |
|
|
|
var slices = new List<string>(); |
|
|
|
var dtypes = new List<TF_DataType>(); |
|
|
|
foreach (var saveable in saveables) |
|
|
|
foreach (var spec in saveable.specs) |
|
|
|
{ |
|
|
|
names.Add(spec.name); |
|
|
|
slices.Add(spec.slice_spec); |
|
|
|
dtypes.Add(spec.dtype); |
|
|
|
} |
|
|
|
|
|
|
|
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); |
|
|
|
} |
|
|
|
|
|
|
|
public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, |
|
|
@@ -83,6 +94,9 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
if (build_save) |
|
|
|
_AddSaveOps(filename_tensor, saveables); |
|
|
|
|
|
|
|
if (build_restore) |
|
|
|
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
@@ -94,5 +108,42 @@ namespace Tensorflow |
|
|
|
var save = save_op(filename_tensor, saveables); |
|
|
|
return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Add operations to restore saveables. |
|
|
|
/// </summary> |
|
|
|
/// <param name="filename_tensor"></param> |
|
|
|
/// <param name="saveables"></param> |
|
|
|
/// <param name="restore_sequentially"></param> |
|
|
|
/// <param name="reshape"></param> |
|
|
|
/// <param name="preferred_shard"></param> |
|
|
|
/// <param name="name"></param> |
|
|
|
/// <returns>An Operation that restores the variables.</returns> |
|
|
|
public Operation _AddRestoreOps(Tensor filename_tensor, |
|
|
|
SaveableObject[] saveables, |
|
|
|
bool restore_sequentially, |
|
|
|
bool reshape, |
|
|
|
int preferred_shard = -1, |
|
|
|
string name = "restore_all") |
|
|
|
{ |
|
|
|
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially); |
|
|
|
var assign_ops = new List<Tensor>(); |
|
|
|
int idx = 0; |
|
|
|
|
|
|
|
foreach(var saveable in saveables) |
|
|
|
{ |
|
|
|
List<TensorShape> shapes = null; |
|
|
|
if (reshape) |
|
|
|
{ |
|
|
|
throw new NotImplementedException("_AddRestoreOps"); |
|
|
|
} |
|
|
|
|
|
|
|
var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); |
|
|
|
idx += saveable.specs.Length; |
|
|
|
assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray())); |
|
|
|
} |
|
|
|
|
|
|
|
return control_flow_ops.group(assign_ops.ToArray(), name: name); |
|
|
|
} |
|
|
|
} |
|
|
|
} |