From d8f8a7f5bfafee1f00808f92046729c49dc47a18 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 9 Feb 2019 11:50:19 -0600 Subject: [PATCH] Saver is inprogress. --- .../Graphs/Graph.Control.cs | 8 +-- src/TensorFlowNET.Core/Graphs/Graph.cs | 5 ++ .../Graphs/_ControlDependenciesController.cs | 16 ++--- src/TensorFlowNET.Core/ITensorOrOperation.cs | 1 + .../Operations/OpDefLibrary.cs | 12 +++- .../Operations/Operation.cs | 10 ++- .../Operations/control_flow_ops.py.cs | 5 +- .../Sessions/BaseSession.cs | 11 +-- .../Tensors/Tensor.Implicit.cs | 5 ++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/dtypes.cs | 7 ++ .../Train/Saving/BaseSaverBuilder.cs | 40 ++++++++--- .../Train/Saving/BulkSaverBuilder.cs | 2 +- .../Train/Saving/ISaverBuilder.cs | 2 +- src/TensorFlowNET.Core/Train/Saving/Saver.cs | 70 +++++++++++++++++-- .../Variables/RefVariable.cs | 12 +--- src/TensorFlowNET.Core/ops.py.cs | 21 ++++++ test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 7 ++ test/TensorFlowNET.UnitTest/VariableTest.cs | 18 +++++ 19 files changed, 199 insertions(+), 55 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index af92e905..c9e3be84 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -28,9 +28,9 @@ namespace Tensorflow /// /// The data input ops for an op to be created. /// A list of control inputs for the op to be created. - private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) + private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) { - Operation[] ret = new Operation[0]; + var ret = new ITensorOrOperation[0]; foreach(var controller in _control_dependencies_stack) { @@ -54,12 +54,12 @@ namespace Tensorflow return ret; } - public _ControlDependenciesController control_dependencies(Operation[] control_inputs) + public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) { if (control_inputs == null) return new _ControlDependenciesController(this, null); - var control_ops = new List(); + var control_ops = new List(); foreach (var c in control_inputs) { control_ops.Add(c); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index aa3eb26e..c9d93182 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -298,6 +298,11 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } + public string[] get_all_collection_keys() + { + return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); + } + public object get_collection(string name, string scope = "") { return _collections.ContainsKey(name) ? _collections[name] : null; diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs index 08302cc1..f1ddcb44 100644 --- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs +++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs @@ -11,20 +11,20 @@ namespace Tensorflow public class _ControlDependenciesController : IPython { private Graph _graph; - private List _control_inputs_val; - private List _seen_nodes; + private List _control_inputs_val; + private List _seen_nodes; private Queue<_ControlDependenciesController> _old_stack; private bool _new_stack; private Context _old_control_flow_context; - public Operation[] control_inputs => _control_inputs_val.ToArray(); + public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); - public _ControlDependenciesController(Graph graph, List control_inputs) + public _ControlDependenciesController(Graph graph, List control_inputs) { _graph = graph; if (control_inputs == null) { - _control_inputs_val = new List(); + _control_inputs_val = new List(); _new_stack = true; } else @@ -33,15 +33,15 @@ namespace Tensorflow _new_stack = false; } - _seen_nodes = new List(); + _seen_nodes = new List(); } - public void add_op(Operation op) + public void add_op(ITensorOrOperation op) { _seen_nodes.Add(op); } - public bool op_in_group(Operation op) + public bool op_in_group(ITensorOrOperation op) { return _seen_nodes.Contains(op); } diff --git a/src/TensorFlowNET.Core/ITensorOrOperation.cs b/src/TensorFlowNET.Core/ITensorOrOperation.cs index c29713b9..f12a0b02 100644 --- a/src/TensorFlowNET.Core/ITensorOrOperation.cs +++ b/src/TensorFlowNET.Core/ITensorOrOperation.cs @@ -11,5 +11,6 @@ namespace Tensorflow public interface ITensorOrOperation { string Device { get; } + Operation op { get; } } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 682f59ec..76dd318a 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -107,7 +107,9 @@ namespace Tensorflow values = ops.internal_convert_to_tensor(values, name: input_name, - as_ref: input_arg.IsRef); + dtype: dtype, + as_ref: input_arg.IsRef, + preferred_dtype: default_dtype); //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) //attrs[input_arg.TypeAttr] = values.dtype; @@ -163,14 +165,20 @@ namespace Tensorflow foreach (var arg in op_def.OutputArg) { + types = new List(); if (!string.IsNullOrEmpty(arg.NumberAttr)) { } else if (!string.IsNullOrEmpty(arg.TypeAttr)) { - output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); + types = new List() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; } + + if (arg.IsRef) + types = types.Select(x => x.as_ref()).ToList(); + + output_types.AddRange(types); } // Add Op to graph diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e6450393..0549edf0 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -16,6 +16,7 @@ namespace Tensorflow private int _id_value; public string type => OpType; + public Operation op => this; private Status status = new Status(); @@ -75,7 +76,7 @@ namespace Tensorflow /// /// /// - public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { Graph = g; @@ -120,6 +121,11 @@ namespace Tensorflow _control_flow_post_processing(); } + public void run(FeedItem[] feed_dict = null, Session session = null) + { + ops._run_using_default_session(this, feed_dict, Graph, session); + } + private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs) { var grouped_inputs = new List(); @@ -204,7 +210,7 @@ namespace Tensorflow public override string ToString() { - return _handle == IntPtr.Zero ? "Undefined" : $"'{Name}' type={OpType}"; + return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{Name}' type={OpType}"; } public static implicit operator Operation(IntPtr handle) => new Operation(handle); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index ebe9a0a4..5c0964ef 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -28,10 +28,7 @@ namespace Tensorflow { var dev = ops_on_device.Keys.First(); var deps = ops_on_device.Values.First(); - if (typeof(T).Name == "Operation") - return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); - else - throw new NotImplementedException("control_flow_ops.group"); + return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name); } // 2-level tree. The root node is the returned NoOp node. diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 13234366..82eee051 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -35,17 +35,8 @@ namespace Tensorflow c_api.TF_DeleteSessionOptions(opts); } - public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null) - { - return _run(fetches, feed_dict); - } - - public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null) - { - return _run(fetches, feed_dict); - } - public virtual NDArray run(Operation fetches, FeedItem[] feed_dict = null) + public virtual NDArray run(T fetches, FeedItem[] feed_dict = null) { return _run(fetches, feed_dict); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index aed5d4a1..79a23e33 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -30,6 +30,11 @@ namespace Tensorflow return tensor._handle; } + public static implicit operator Operation(Tensor tensor) + { + return tensor.op; + } + public static implicit operator Tensor(IntPtr handle) { return new Tensor(handle); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 2f5b4542..f2674fcb 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -261,7 +261,7 @@ namespace Tensorflow } } - return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; + return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index af429ee1..31bfe3e2 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -83,6 +83,13 @@ namespace Tensorflow type; } + public static TF_DataType as_ref(this TF_DataType type) + { + return (int)type < 100 ? + (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : + type; + } + public static bool is_complex(this TF_DataType type) { return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index ccd963d6..ea7377af 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -7,9 +7,9 @@ namespace Tensorflow { public class BaseSaverBuilder { - protected int _write_version; + protected SaverDef.Types.CheckpointFormatVersion _write_version; - public BaseSaverBuilder(int write_version = 2) + public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) { _write_version = write_version; } @@ -30,7 +30,7 @@ namespace Tensorflow } } - if (_write_version == 2) + if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) { return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); } @@ -60,7 +60,7 @@ namespace Tensorflow bool reshape = false, bool sharded = false, int max_to_keep = 5, - double keep_checkpoint_every_n_hours = 10000, + float keep_checkpoint_every_n_hours = 10000, string name = "", bool restore_sequentially = false, string filename = "model", @@ -76,7 +76,10 @@ namespace Tensorflow if (max_to_keep < 0) max_to_keep = 0; - Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => + Tensor save_tensor = null; + Operation restore_op = null; + + return Python.with(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => { name = scope; @@ -93,14 +96,35 @@ namespace Tensorflow else { if (build_save) - _AddSaveOps(filename_tensor, saveables); + save_tensor = _AddSaveOps(filename_tensor, saveables); if (build_restore) - _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); + restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); } + + var graph = ops.get_default_graph(); + var check_collection_list = graph.get_all_collection_keys(); + foreach (var collection_type in check_collection_list) + { + foreach (var element in graph.get_collection(collection_type) as IList) + { + + } + } + + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.Name, + MaxToKeep = max_to_keep, + Sharded = sharded, + KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours, + Version = _write_version + }; }); - throw new NotImplementedException(""); + } public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) diff --git a/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs index b99b75f0..1919ab25 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BulkSaverBuilder.cs @@ -6,7 +6,7 @@ namespace Tensorflow { public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder { - public BulkSaverBuilder(int write_version = 2) : base(write_version) + public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version) { } diff --git a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs index ed69919e..9277ae9a 100644 --- a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs @@ -14,7 +14,7 @@ namespace Tensorflow bool reshape = false, bool sharded = false, int max_to_keep = 5, - double keep_checkpoint_every_n_hours = 10000, + float keep_checkpoint_every_n_hours = 10000, string name = "", bool restore_sequentially = false, string filename = "model", diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 5e7d6333..577fba58 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; namespace Tensorflow @@ -13,30 +14,33 @@ namespace Tensorflow private bool _reshape; private bool _sharded; private int _max_to_keep; - private double _keep_checkpoint_every_n_hours; + private float _keep_checkpoint_every_n_hours; private string _name; private bool _restore_sequentially; private SaverDef _saver_def; private ISaverBuilder _builder; private bool _allow_empty; private bool _is_built; - private int _write_version; + private SaverDef.Types.CheckpointFormatVersion _write_version; private bool _pad_step_number; private string _filename; private bool _is_empty; + private float _next_checkpoint_time; + private bool _save_relative_paths; + private bool? _object_restore_saver; public Saver(RefVariable[] var_list = null, bool reshape = false, bool sharded = false, int max_to_keep = 5, - double keep_checkpoint_every_n_hours = 10000, + float keep_checkpoint_every_n_hours = 10000, string name = "", bool restore_sequentially = false, SaverDef saver_def = null, ISaverBuilder builder = null, bool defer_build = false, bool allow_empty = false, - int write_version = 2, + SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2, bool pad_step_number = false, bool save_relative_paths = false, string filename = "") @@ -56,6 +60,14 @@ namespace Tensorflow if (!defer_build) build(); + if(_saver_def != null) + { + _check_saver_def(); + _write_version = _saver_def.Version; + } + + _save_relative_paths = save_relative_paths; + _object_restore_saver = null; } public void build() @@ -106,8 +118,56 @@ namespace Tensorflow { throw new NotImplementedException(""); } - + _check_saver_def(); + + _next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; + } + + private void _check_saver_def() + { + if (!tf.context.executing_eagerly()) + { + if (string.IsNullOrEmpty(_saver_def.SaveTensorName)) + throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}"); + if (string.IsNullOrEmpty(_saver_def.RestoreOpName)) + throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}"); + } + } + + public string save(Session sess, + string save_path, + string global_step = "", + string meta_graph_suffix = "meta", + bool write_meta_graph = true, + bool write_state = true, + bool strip_default_attrs = false) + { + string latest_filename = "checkpoint"; + string model_checkpoint_path = ""; + string checkpoint_file = ""; + + if (!string.IsNullOrEmpty(global_step)) + { + + } + else + { + checkpoint_file = save_path; + } + + var save_path_parent = Path.GetDirectoryName(save_path); + + if (!_is_empty) + { + /*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { + new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) + });*/ + } + + throw new NotImplementedException(""); + + return model_checkpoint_path; } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index b908f657..59333f3d 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -185,18 +185,12 @@ namespace Tensorflow /// A `Tensor` that will hold the new value of this variable after /// the assignment has completed. /// - public T assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) - where T : ITensorOrOperation + public ITensorOrOperation assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) { var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); if (read_value) - return (T)Convert.ChangeType(assign, typeof(T)); - return (T)Convert.ChangeType(assign.op, typeof(T)); - } - - public Tensor assign(Tensor value, bool use_locking = false, string name = "") - { - return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); + return assign; + return assign.op; } public override string ToString() diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 15d17b56..74932bda 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -292,6 +292,27 @@ namespace Tensorflow return tf.Session(); } + public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) + { + if (session == null) + { + session = get_default_session(); + if (session == null) + throw new ValueError("Cannot execute operation using `run()`: No default " + + "session is registered. Use `with " + + "sess.as_default():` or pass an explicit session to " + + "`run(session=sess)`"); + } + + if (session.graph != graph) + throw new ValueError("Cannot use the default session to execute operation: " + + "the operation's graph is different from the " + + "session's graph. Pass an explicit session to " + + "run(session=sess)."); + + session.run(operation, feed_dict); + } + public static Func get_gradient_function(Operation op) { if (op.inputs == null) return null; diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 40775fcb..541bf893 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -27,6 +27,13 @@ namespace TensorFlowNET.UnitTest with(tf.Session(), sess => { sess.run(init_op); + // o some work with the model. + inc_v1.op.run(); + dec_v2.op.run(); + + // Save the variables to disk. + var save_path = saver.save(sess, "/tmp/model.ckpt"); + Console.WriteLine($"Model saved in path: {save_path}"); }); } } diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index fff5b277..ffa2d284 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -46,6 +46,24 @@ namespace TensorFlowNET.UnitTest } } + [TestMethod] + public void Assign() + { + var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); + + var inc_v1 = v1.assign(v1 + 1.0f); + + // Add an op to initialize the variables. + var init_op = tf.global_variables_initializer(); + + with(tf.Session(), sess => + { + sess.run(init_op); + // o some work with the model. + inc_v1.op.run(); + }); + } + /// /// https://databricks.com/tensorflow/variables ///