Browse Source

fix import_meta_graph without VariableV1 bug. #453

tags/v0.13
Oceania2018 5 years ago
parent
commit
b680700dae
5 changed files with 37 additions and 16 deletions
  1. +16
    -0
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  2. +0
    -4
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  3. +5
    -5
      src/TensorFlowNET.Core/Training/Saving/Saver.cs
  4. +15
    -4
      src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs
  5. +1
    -3
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs

+ 16
- 0
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -268,6 +268,22 @@ namespace Tensorflow

switch (graph.get_collection(key))
{
case List<VariableV1> collection_list:
col_def.BytesList = new Types.BytesList();
foreach (var x in collection_list)
{
if(x is RefVariable x_ref_var)
{
var proto = x_ref_var.to_proto(export_scope);
col_def.BytesList.Value.Add(proto.ToByteString());
}
else
{
Console.WriteLine($"Can't find to_proto method for type {x.GetType().Name}");
}
}

break;
case List<RefVariable> collection_list:
col_def.BytesList = new Types.BytesList();
foreach (var x in collection_list)


+ 0
- 4
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -52,10 +52,6 @@ namespace Tensorflow
public void _set_control_flow_context(ControlFlowContext ctx)
{
if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc"))
{
}
_control_flow_context = ctx;
}


+ 5
- 5
src/TensorFlowNET.Core/Training/Saving/Saver.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
@@ -170,7 +171,7 @@ namespace Tensorflow
{
if (string.IsNullOrEmpty(latest_filename))
latest_filename = "checkpoint";
object model_checkpoint_path = "";
NDArray[] model_checkpoint_path = null;
string checkpoint_file = "";

if (global_step > 0)
@@ -183,15 +184,14 @@ namespace Tensorflow
if (!_is_empty)
{
model_checkpoint_path = sess.run(_saver_def.SaveTensorName,
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
);
(_saver_def.FilenameTensorName, checkpoint_file));

if (write_state)
{
_RecordLastCheckpoint(model_checkpoint_path.ToString());
_RecordLastCheckpoint(model_checkpoint_path[0].ToString());
checkpoint_management.update_checkpoint_state_internal(
save_dir: save_path_parent,
model_checkpoint_path: model_checkpoint_path.ToString(),
model_checkpoint_path: model_checkpoint_path[0].ToString(),
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(),
latest_filename: latest_filename,
save_relative_paths: _save_relative_paths);


+ 15
- 4
src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs View File

@@ -44,8 +44,7 @@ namespace Tensorflow
float? last_preserved_timestamp = null
)
{
CheckpointState ckpt = null;

CheckpointState ckpt = null;
// Writes the "checkpoint" file for the coordinator for later restoration.
string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename);
if (save_relative_paths)
@@ -65,7 +64,12 @@ namespace Tensorflow
throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " +
"checkpoint state. Please use a different save path.");

File.WriteAllText(coord_checkpoint_filename, ckpt.ToString());
// File.WriteAllText(coord_checkpoint_filename, ckpt.ToString());
File.WriteAllLines(coord_checkpoint_filename, new[]
{
$"model_checkpoint_path: \"{ckpt.ModelCheckpointPath}\"",
$"all_model_checkpoint_paths: \"{ckpt.AllModelCheckpointPaths[0]}\"",
});
}

/// <summary>
@@ -98,7 +102,14 @@ namespace Tensorflow
all_model_checkpoint_paths.Add(model_checkpoint_path);

// Relative paths need to be rewritten to be relative to the "save_dir"
// if model_checkpoint_path already contains "save_dir".
if (model_checkpoint_path.StartsWith(save_dir))
{
model_checkpoint_path = model_checkpoint_path.Substring(save_dir.Length + 1);
all_model_checkpoint_paths = all_model_checkpoint_paths
.Select(x => x.Substring(save_dir.Length + 1))
.ToList();
}

var coord_checkpoint_proto = new CheckpointState()
{


+ 1
- 3
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -29,14 +29,12 @@ namespace Tensorflow
{
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);

var meta = meta_graph.import_scoped_meta_graph_with_return_elements(
var (imported_vars, imported_return_elements) = meta_graph.import_scoped_meta_graph_with_return_elements(
meta_graph_def,
clear_devices: clear_devices,
import_scope: import_scope,
return_elements: return_elements);

var (imported_vars, imported_return_elements) = meta;

var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars);



Loading…
Cancel
Save