From b680700dae396dae7600ad2d3ccfcc1cf959a7ca Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 28 Nov 2019 09:09:50 -0600 Subject: [PATCH] fix import_meta_graph without VariableV1 bug. #453 --- .../Framework/meta_graph.cs | 16 ++++++++++++++++ .../Operations/Operation.Control.cs | 4 ---- .../Training/Saving/Saver.cs | 10 +++++----- .../Saving/checkpoint_management.py.cs | 19 +++++++++++++++---- .../Training/Saving/saver.py.cs | 4 +--- 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 6182a3a9..c3cb62cd 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -268,6 +268,22 @@ namespace Tensorflow switch (graph.get_collection(key)) { + case List collection_list: + col_def.BytesList = new Types.BytesList(); + foreach (var x in collection_list) + { + if(x is RefVariable x_ref_var) + { + var proto = x_ref_var.to_proto(export_scope); + col_def.BytesList.Value.Add(proto.ToByteString()); + } + else + { + Console.WriteLine($"Can't find to_proto method for type {x.GetType().Name}"); + } + } + + break; case List collection_list: col_def.BytesList = new Types.BytesList(); foreach (var x in collection_list) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index c9ae7071..ba7b0829 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -52,10 +52,6 @@ namespace Tensorflow public void _set_control_flow_context(ControlFlowContext ctx) { - if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")) - { - - } _control_flow_context = ctx; } diff --git a/src/TensorFlowNET.Core/Training/Saving/Saver.cs b/src/TensorFlowNET.Core/Training/Saving/Saver.cs index 3f72a438..9cb4af10 100644 --- a/src/TensorFlowNET.Core/Training/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Training/Saving/Saver.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using System.IO; @@ -170,7 +171,7 @@ namespace Tensorflow { if (string.IsNullOrEmpty(latest_filename)) latest_filename = "checkpoint"; - object model_checkpoint_path = ""; + NDArray[] model_checkpoint_path = null; string checkpoint_file = ""; if (global_step > 0) @@ -183,15 +184,14 @@ namespace Tensorflow if (!_is_empty) { model_checkpoint_path = sess.run(_saver_def.SaveTensorName, - new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) - ); + (_saver_def.FilenameTensorName, checkpoint_file)); if (write_state) { - _RecordLastCheckpoint(model_checkpoint_path.ToString()); + _RecordLastCheckpoint(model_checkpoint_path[0].ToString()); checkpoint_management.update_checkpoint_state_internal( save_dir: save_path_parent, - model_checkpoint_path: model_checkpoint_path.ToString(), + model_checkpoint_path: model_checkpoint_path[0].ToString(), all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), latest_filename: latest_filename, save_relative_paths: _save_relative_paths); diff --git a/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs index 47f64b91..5f1cfe8c 100644 --- a/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs @@ -44,8 +44,7 @@ namespace Tensorflow float? last_preserved_timestamp = null ) { - CheckpointState ckpt = null; - + CheckpointState ckpt = null; // Writes the "checkpoint" file for the coordinator for later restoration. string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); if (save_relative_paths) @@ -65,7 +64,12 @@ namespace Tensorflow throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + "checkpoint state. Please use a different save path."); - File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); + // File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); + File.WriteAllLines(coord_checkpoint_filename, new[] + { + $"model_checkpoint_path: \"{ckpt.ModelCheckpointPath}\"", + $"all_model_checkpoint_paths: \"{ckpt.AllModelCheckpointPaths[0]}\"", + }); } /// @@ -98,7 +102,14 @@ namespace Tensorflow all_model_checkpoint_paths.Add(model_checkpoint_path); // Relative paths need to be rewritten to be relative to the "save_dir" - // if model_checkpoint_path already contains "save_dir". + if (model_checkpoint_path.StartsWith(save_dir)) + { + model_checkpoint_path = model_checkpoint_path.Substring(save_dir.Length + 1); + all_model_checkpoint_paths = all_model_checkpoint_paths + .Select(x => x.Substring(save_dir.Length + 1)) + .ToList(); + } + var coord_checkpoint_proto = new CheckpointState() { diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs index 49c9bfc5..2b75947b 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs @@ -29,14 +29,12 @@ namespace Tensorflow { var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); - var meta = meta_graph.import_scoped_meta_graph_with_return_elements( + var (imported_vars, imported_return_elements) = meta_graph.import_scoped_meta_graph_with_return_elements( meta_graph_def, clear_devices: clear_devices, import_scope: import_scope, return_elements: return_elements); - var (imported_vars, imported_return_elements) = meta; - var saver = _create_saver_from_imported_meta_graph( meta_graph_def, import_scope, imported_vars);