Browse Source

Saver is inprogress.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
d8f8a7f5bf
19 changed files with 199 additions and 55 deletions
  1. +4
    -4
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +8
    -8
      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  4. +1
    -0
      src/TensorFlowNET.Core/ITensorOrOperation.cs
  5. +10
    -2
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  6. +8
    -2
      src/TensorFlowNET.Core/Operations/Operation.cs
  7. +1
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  8. +1
    -10
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  9. +5
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  11. +7
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  12. +32
    -8
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
  15. +65
    -5
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  16. +3
    -9
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  17. +21
    -0
      src/TensorFlowNET.Core/ops.py.cs
  18. +7
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs
  19. +18
    -0
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 4
- 4
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -28,9 +28,9 @@ namespace Tensorflow
/// </summary>
/// <param name="input_ops">The data input ops for an op to be created.</param>
/// <returns>A list of control inputs for the op to be created.</returns>
private Operation[] _control_dependencies_for_inputs(Operation[] input_ops)
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
{
Operation[] ret = new Operation[0];
var ret = new ITensorOrOperation[0];

foreach(var controller in _control_dependencies_stack)
{
@@ -54,12 +54,12 @@ namespace Tensorflow
return ret;
}

public _ControlDependenciesController control_dependencies(Operation[] control_inputs)
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
{
if (control_inputs == null)
return new _ControlDependenciesController(this, null);

var control_ops = new List<Operation>();
var control_ops = new List<ITensorOrOperation>();
foreach (var c in control_inputs)
{
control_ops.Add(c);


+ 5
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -298,6 +298,11 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public string[] get_all_collection_keys()
{
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
}

public object get_collection(string name, string scope = "")
{
return _collections.ContainsKey(name) ? _collections[name] : null;


+ 8
- 8
src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs View File

@@ -11,20 +11,20 @@ namespace Tensorflow
public class _ControlDependenciesController : IPython
{
private Graph _graph;
private List<Operation> _control_inputs_val;
private List<Operation> _seen_nodes;
private List<ITensorOrOperation> _control_inputs_val;
private List<ITensorOrOperation> _seen_nodes;
private Queue<_ControlDependenciesController> _old_stack;
private bool _new_stack;
private Context _old_control_flow_context;

public Operation[] control_inputs => _control_inputs_val.ToArray();
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();

public _ControlDependenciesController(Graph graph, List<Operation> control_inputs)
public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> control_inputs)
{
_graph = graph;
if (control_inputs == null)
{
_control_inputs_val = new List<Operation>();
_control_inputs_val = new List<ITensorOrOperation>();
_new_stack = true;
}
else
@@ -33,15 +33,15 @@ namespace Tensorflow
_new_stack = false;
}

_seen_nodes = new List<Operation>();
_seen_nodes = new List<ITensorOrOperation>();
}

public void add_op(Operation op)
public void add_op(ITensorOrOperation op)
{
_seen_nodes.Add(op);
}

public bool op_in_group(Operation op)
public bool op_in_group(ITensorOrOperation op)
{
return _seen_nodes.Contains(op);
}


+ 1
- 0
src/TensorFlowNET.Core/ITensorOrOperation.cs View File

@@ -11,5 +11,6 @@ namespace Tensorflow
public interface ITensorOrOperation
{
string Device { get; }
Operation op { get; }
}
}

+ 10
- 2
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -107,7 +107,9 @@ namespace Tensorflow

values = ops.internal_convert_to_tensor(values,
name: input_name,
as_ref: input_arg.IsRef);
dtype: dtype,
as_ref: input_arg.IsRef,
preferred_dtype: default_dtype);

//if (!String.IsNullOrEmpty(input_arg.TypeAttr))
//attrs[input_arg.TypeAttr] = values.dtype;
@@ -163,14 +165,20 @@ namespace Tensorflow

foreach (var arg in op_def.OutputArg)
{
types = new List<TF_DataType>();
if (!string.IsNullOrEmpty(arg.NumberAttr))
{

}
else if (!string.IsNullOrEmpty(arg.TypeAttr))
{
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type };
}

if (arg.IsRef)
types = types.Select(x => x.as_ref()).ToList();

output_types.AddRange(types);
}

// Add Op to graph


+ 8
- 2
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -16,6 +16,7 @@ namespace Tensorflow
private int _id_value;

public string type => OpType;
public Operation op => this;

private Status status = new Status();

@@ -75,7 +76,7 @@ namespace Tensorflow
/// </param>
/// <param name="original_op"></param>
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
Graph = g;

@@ -120,6 +121,11 @@ namespace Tensorflow
_control_flow_post_processing();
}

public void run(FeedItem[] feed_dict = null, Session session = null)
{
ops._run_using_default_session(this, feed_dict, Graph, session);
}

private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
{
var grouped_inputs = new List<object>();
@@ -204,7 +210,7 @@ namespace Tensorflow

public override string ToString()
{
return _handle == IntPtr.Zero ? "Undefined" : $"'{Name}' type={OpType}";
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{Name}' type={OpType}";
}

public static implicit operator Operation(IntPtr handle) => new Operation(handle);


+ 1
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -28,10 +28,7 @@ namespace Tensorflow
{
var dev = ops_on_device.Keys.First();
var deps = ops_on_device.Values.First();
if (typeof(T).Name == "Operation")
return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name);
else
throw new NotImplementedException("control_flow_ops.group");
return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name);
}

// 2-level tree. The root node is the returned NoOp node.


+ 1
- 10
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -35,17 +35,8 @@ namespace Tensorflow
c_api.TF_DeleteSessionOptions(opts);
}

public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);
}

public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);
}

public virtual NDArray run(Operation fetches, FeedItem[] feed_dict = null)
public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);
}


+ 5
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -30,6 +30,11 @@ namespace Tensorflow
return tensor._handle;
}

public static implicit operator Operation(Tensor tensor)
{
return tensor.op;
}

public static implicit operator Tensor(IntPtr handle)
{
return new Tensor(handle);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -261,7 +261,7 @@ namespace Tensorflow
}
}

return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
}

public void Dispose()


+ 7
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -83,6 +83,13 @@ namespace Tensorflow
type;
}

public static TF_DataType as_ref(this TF_DataType type)
{
return (int)type < 100 ?
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) :
type;
}

public static bool is_complex(this TF_DataType type)
{
return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128;


+ 32
- 8
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -7,9 +7,9 @@ namespace Tensorflow
{
public class BaseSaverBuilder
{
protected int _write_version;
protected SaverDef.Types.CheckpointFormatVersion _write_version;

public BaseSaverBuilder(int write_version = 2)
public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2)
{
_write_version = write_version;
}
@@ -30,7 +30,7 @@ namespace Tensorflow
}
}

if (_write_version == 2)
if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2)
{
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
}
@@ -60,7 +60,7 @@ namespace Tensorflow
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
float keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
string filename = "model",
@@ -76,7 +76,10 @@ namespace Tensorflow
if (max_to_keep < 0)
max_to_keep = 0;

Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
Tensor save_tensor = null;
Operation restore_op = null;

return Python.with<ops.name_scope, SaverDef>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
{
name = scope;

@@ -93,14 +96,35 @@ namespace Tensorflow
else
{
if (build_save)
_AddSaveOps(filename_tensor, saveables);
save_tensor = _AddSaveOps(filename_tensor, saveables);

if (build_restore)
_AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape);
restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape);
}

var graph = ops.get_default_graph();
var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list)
{
foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>)
{

}
}

return new SaverDef()
{
FilenameTensorName = filename_tensor.name,
SaveTensorName = save_tensor.name,
RestoreOpName = restore_op.Name,
MaxToKeep = max_to_keep,
Sharded = sharded,
KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours,
Version = _write_version
};
});

throw new NotImplementedException("");
}

public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables)


+ 1
- 1
src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow
{
public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder
{
public BulkSaverBuilder(int write_version = 2) : base(write_version)
public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version)
{

}


+ 1
- 1
src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
float keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
string filename = "model",


+ 65
- 5
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow
@@ -13,30 +14,33 @@ namespace Tensorflow
private bool _reshape;
private bool _sharded;
private int _max_to_keep;
private double _keep_checkpoint_every_n_hours;
private float _keep_checkpoint_every_n_hours;
private string _name;
private bool _restore_sequentially;
private SaverDef _saver_def;
private ISaverBuilder _builder;
private bool _allow_empty;
private bool _is_built;
private int _write_version;
private SaverDef.Types.CheckpointFormatVersion _write_version;
private bool _pad_step_number;
private string _filename;
private bool _is_empty;
private float _next_checkpoint_time;
private bool _save_relative_paths;
private bool? _object_restore_saver;

public Saver(RefVariable[] var_list = null,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
double keep_checkpoint_every_n_hours = 10000,
float keep_checkpoint_every_n_hours = 10000,
string name = "",
bool restore_sequentially = false,
SaverDef saver_def = null,
ISaverBuilder builder = null,
bool defer_build = false,
bool allow_empty = false,
int write_version = 2,
SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2,
bool pad_step_number = false,
bool save_relative_paths = false,
string filename = "")
@@ -56,6 +60,14 @@ namespace Tensorflow

if (!defer_build)
build();
if(_saver_def != null)
{
_check_saver_def();
_write_version = _saver_def.Version;
}

_save_relative_paths = save_relative_paths;
_object_restore_saver = null;
}

public void build()
@@ -106,8 +118,56 @@ namespace Tensorflow
{
throw new NotImplementedException("");
}

_check_saver_def();

_next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600;
}

private void _check_saver_def()
{
if (!tf.context.executing_eagerly())
{
if (string.IsNullOrEmpty(_saver_def.SaveTensorName))
throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}");
if (string.IsNullOrEmpty(_saver_def.RestoreOpName))
throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}");
}
}

public string save(Session sess,
string save_path,
string global_step = "",
string meta_graph_suffix = "meta",
bool write_meta_graph = true,
bool write_state = true,
bool strip_default_attrs = false)
{
string latest_filename = "checkpoint";
string model_checkpoint_path = "";
string checkpoint_file = "";

if (!string.IsNullOrEmpty(global_step))
{

}
else
{
checkpoint_file = save_path;
}

var save_path_parent = Path.GetDirectoryName(save_path);

if (!_is_empty)
{
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
});*/
}

throw new NotImplementedException("");

return model_checkpoint_path;
}
}
}

+ 3
- 9
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -185,18 +185,12 @@ namespace Tensorflow
/// A `Tensor` that will hold the new value of this variable after
/// the assignment has completed.
/// </returns>
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
where T : ITensorOrOperation
public ITensorOrOperation assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
{
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
if (read_value)
return (T)Convert.ChangeType(assign, typeof(T));
return (T)Convert.ChangeType(assign.op, typeof(T));
}

public Tensor assign(Tensor value, bool use_locking = false, string name = "")
{
return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
return assign;
return assign.op;
}

public override string ToString()


+ 21
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -292,6 +292,27 @@ namespace Tensorflow
return tf.Session();
}

public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)
{
if (session == null)
{
session = get_default_session();
if (session == null)
throw new ValueError("Cannot execute operation using `run()`: No default " +
"session is registered. Use `with " +
"sess.as_default():` or pass an explicit session to " +
"`run(session=sess)`");
}

if (session.graph != graph)
throw new ValueError("Cannot use the default session to execute operation: " +
"the operation's graph is different from the " +
"session's graph. Pass an explicit session to " +
"run(session=sess).");

session.run(operation, feed_dict);
}

public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;


+ 7
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -27,6 +27,13 @@ namespace TensorFlowNET.UnitTest
with<Session>(tf.Session(), sess =>
{
sess.run(init_op);
// o some work with the model.
inc_v1.op.run();
dec_v2.op.run();

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
});
}
}


+ 18
- 0
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -46,6 +46,24 @@ namespace TensorFlowNET.UnitTest
}
}

[TestMethod]
public void Assign()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);

var inc_v1 = v1.assign(v1 + 1.0f);

// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();

with<Session>(tf.Session(), sess =>
{
sess.run(init_op);
// o some work with the model.
inc_v1.op.run();
});
}

/// <summary>
/// https://databricks.com/tensorflow/variables
/// </summary>


Loading…
Cancel
Save