From 9b9a25c712868edfd9904147a71d78c1b684d457 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 14 Sep 2019 09:05:48 -0500 Subject: [PATCH] _create_and_assert_global_step --- src/TensorFlowNET.Core/APIs/tf.train.cs | 6 +++ src/TensorFlowNET.Core/APIs/tf.variable.cs | 4 +- .../Estimators/Estimator.cs | 36 ++++++++++--- src/TensorFlowNET.Core/Graphs/Graph.cs | 4 ++ .../Operations/array_ops.py.cs | 2 + src/TensorFlowNET.Core/Tensors/TF_DataType.cs | 1 + src/TensorFlowNET.Core/Tensors/dtypes.cs | 3 +- src/TensorFlowNET.Core/Train/Optimizer.cs | 4 +- .../Train/Saving/checkpoint_management.py.cs | 20 +++++++- src/TensorFlowNET.Core/Train/TrainingUtil.cs | 51 +++++++++++++++++++ .../Variables/VariableScope.cs | 3 ++ .../Variables/_VariableStore.cs | 6 +++ .../Variables/variable_scope.py.cs | 2 + 13 files changed, 130 insertions(+), 12 deletions(-) create mode 100644 src/TensorFlowNET.Core/Train/TrainingUtil.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 54e4aea1..d6de08f4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -25,6 +25,12 @@ namespace Tensorflow public class train_internal { + public RefVariable create_global_step(Graph graph) + => TrainingUtil.create_global_step(graph); + + public RefVariable get_global_step(Graph graph) + => TrainingUtil.get_global_step(graph); + public Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index b3c5bf43..ef6d65c9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -46,6 +46,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + List collections = null, bool? use_resource = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -60,7 +61,8 @@ namespace Tensorflow use_resource: use_resource, validate_shape: validate_shape, initializer: initializer, - trainable: trainable); + trainable: trainable, + collections: collections); } } } diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs index 2c9ae7d9..2570206f 100644 --- a/src/TensorFlowNET.Core/Estimators/Estimator.cs +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.IO; using System.Text; using static Tensorflow.Binding; @@ -34,32 +36,52 @@ namespace Tensorflow.Estimators if(max_steps > 0) { var start_step = _load_global_step_from_checkpoint_dir(_model_dir); + if (max_steps <= start_step) + { + Console.WriteLine("Skipping training since max_steps has already saved."); + return this; + } } - _train_model(); + _train_model(input_fn); throw new NotImplementedException(""); } private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) { - var cp = tf.train.latest_checkpoint(checkpoint_dir); + // var cp = tf.train.latest_checkpoint(checkpoint_dir); + // should use NewCheckpointReader (not implemented) + var cp = tf.train.get_checkpoint_state(checkpoint_dir); - return 0; + return cp.AllModelCheckpointPaths.Count - 1; } - private void _train_model() + private void _train_model(Action input_fn) { - _train_model_default(); + _train_model_default(input_fn); } - private void _train_model_default() + private void _train_model_default(Action input_fn) { using (var g = tf.Graph().as_default()) { - + var global_step_tensor = _create_and_assert_global_step(g); } } + private Tensor _create_and_assert_global_step(Graph graph) + { + var step = _create_global_step(graph); + Debug.Assert(step == tf.train.get_global_step(graph)); + Debug.Assert(step.dtype.is_integer()); + return step; + } + + private RefVariable _create_global_step(Graph graph) + { + return tf.train.create_global_step(graph); + } + public void __init__() { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index fa806156..013365d2 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -175,6 +175,10 @@ namespace Tensorflow if (_nodes_by_name.ContainsKey(op_name)) return _nodes_by_name[op_name].outputs[out_n]; + else + throw new KeyError($"The name {name} refers to a Tensor which does not " + + $"exist. The operation, {op_name}, does not exist in the " + + "graph."); } else if (!name.Contains(":") & allow_operation) { diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index cf38a8c3..83a469dc 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -54,6 +54,8 @@ namespace Tensorflow return _constant_if_small(0.0D, shape, dtype, name); case TF_DataType.TF_FLOAT: return _constant_if_small(0.0F, shape, dtype, name); + case TF_DataType.TF_INT64: + return _constant_if_small(0l, shape, dtype, name); case TF_DataType.TF_INT32: return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT8: diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 97496df1..c916b321 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -35,5 +35,6 @@ DtFloatRef = 101, // DT_FLOAT_REF DtDoubleRef = 102, // DT_DOUBLE_REF DtInt32Ref = 103, // DT_INT32_REF + DtInt64Ref = 109 // DT_INT64_REF } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 41d9009c..90b1b80d 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -246,7 +246,8 @@ namespace Tensorflow public static bool is_integer(this TF_DataType type) { return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 || - type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64; + type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 || + type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; } public static bool is_floating(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 90c8674d..0d6c304a 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -249,7 +249,9 @@ namespace Tensorflow { _maybe_initialize_trackable(); v = variable_scope.default_variable_creator( - initial_value, name: name, trainable: false, + initial_value, + name: name, + trainable: false, use_resource: resource_variable_ops.is_resource_variable( colocate_with)); diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs index c23368c1..47f64b91 100644 --- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs @@ -174,8 +174,24 @@ namespace Tensorflow var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); if (File.Exists(coord_checkpoint_filename)) { - var file_content = File.ReadAllBytes(coord_checkpoint_filename); - var ckpt = CheckpointState.Parser.ParseFrom(file_content); + var file_content = File.ReadAllLines(coord_checkpoint_filename); + // https://github.com/protocolbuffers/protobuf/issues/6654 + // var ckpt = CheckpointState.Parser.ParseFrom(file_content); + var ckpt = new CheckpointState(); + var field = CheckpointState.Descriptor.FindFieldByName("model_checkpoint_path"); + ckpt.ModelCheckpointPath = file_content.FirstOrDefault(x => x.StartsWith(field.Name + ":")).Substring(field.Name.Length + 2); + // remove first and last quote. + ckpt.ModelCheckpointPath = ckpt.ModelCheckpointPath.Substring(1, ckpt.ModelCheckpointPath.Length - 2); + + field = CheckpointState.Descriptor.FindFieldByName("all_model_checkpoint_paths"); + file_content.Where(x => x.StartsWith(field.Name + ":")) + .ToList() + .ForEach(x => + { + string value = x.Substring(field.Name.Length + 2); + ckpt.AllModelCheckpointPaths.Add(value.Substring(1, value.Length - 2)); + }); + if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); // For relative model_checkpoint_path and all_model_checkpoint_paths, diff --git a/src/TensorFlowNET.Core/Train/TrainingUtil.cs b/src/TensorFlowNET.Core/Train/TrainingUtil.cs new file mode 100644 index 00000000..1b6f7f81 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/TrainingUtil.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class TrainingUtil + { + public static RefVariable create_global_step(Graph graph) + { + graph = graph ?? ops.get_default_graph(); + if (get_global_step(graph) != null) + throw new ValueError("global_step already exists."); + + // Create in proper graph and base name_scope. + var g = graph.as_default(); + g.name_scope(null); + var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, + initializer: tf.zeros_initializer, + trainable: false, + aggregation: VariableAggregation.OnlyFirstReplica, + collections: new List { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP }); + return v; + } + + public static RefVariable get_global_step(Graph graph) + { + graph = graph ?? ops.get_default_graph(); + RefVariable global_step_tensor = null; + var global_step_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP); + if (global_step_tensors.Count == 1) + { + global_step_tensor = global_step_tensors[0]; + } + else + { + try + { + global_step_tensor = graph.get_tensor_by_name("global_step:0"); + } + catch (KeyError) + { + return null; + } + } + + return global_step_tensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index afc221c8..ad7750a1 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; using static Tensorflow.Binding; namespace Tensorflow @@ -50,6 +51,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + List collections = null, bool? use_resource = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -67,6 +69,7 @@ namespace Tensorflow initializer: initializer, reuse: resue, trainable: trainable, + collections: collections, synchronization: synchronization, aggregation: aggregation); }); diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 8957568e..d0fbf161 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -42,6 +42,7 @@ namespace Tensorflow object initializer = null, // IInitializer or Tensor bool? reuse = null, bool? trainable = null, + List collections = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) @@ -54,6 +55,7 @@ namespace Tensorflow dtype: dtype, initializer: initializer, trainable: trainable, + collections: collections, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); @@ -64,6 +66,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.TF_FLOAT, object initializer = null, bool? trainable = null, + List collections = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) @@ -77,6 +80,7 @@ namespace Tensorflow dtype: dtype, initializer: init, trainable: trainable, + collections: collections, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); @@ -112,6 +116,7 @@ namespace Tensorflow IInitializer initializer = null, bool reuse = false, bool? trainable = null, + List collections = null, bool validate_shape = false, bool? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -157,6 +162,7 @@ namespace Tensorflow v = variable_scope.default_variable_creator(init_val, name: name, trainable: trainable, + collections: collections, dtype: variable_dtype, validate_shape: validate_shape, synchronization: synchronization, diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 6bc83052..4f357b12 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -175,6 +175,7 @@ namespace Tensorflow public static RefVariable default_variable_creator(object initial_value, string name = null, bool? trainable = null, + List collections = null, TF_DataType dtype = TF_DataType.DtInvalid, bool validate_shape = false, bool ? use_resource = null, @@ -199,6 +200,7 @@ namespace Tensorflow return new RefVariable(initial_value, trainable: trainable.Value, validate_shape: validate_shape, + collections: collections, name: name, dtype: dtype); }