@@ -268,6 +268,22 @@ namespace Tensorflow | |||||
switch (graph.get_collection(key)) | switch (graph.get_collection(key)) | ||||
{ | { | ||||
case List<VariableV1> collection_list: | |||||
col_def.BytesList = new Types.BytesList(); | |||||
foreach (var x in collection_list) | |||||
{ | |||||
if(x is RefVariable x_ref_var) | |||||
{ | |||||
var proto = x_ref_var.to_proto(export_scope); | |||||
col_def.BytesList.Value.Add(proto.ToByteString()); | |||||
} | |||||
else | |||||
{ | |||||
Console.WriteLine($"Can't find to_proto method for type {x.GetType().Name}"); | |||||
} | |||||
} | |||||
break; | |||||
case List<RefVariable> collection_list: | case List<RefVariable> collection_list: | ||||
col_def.BytesList = new Types.BytesList(); | col_def.BytesList = new Types.BytesList(); | ||||
foreach (var x in collection_list) | foreach (var x in collection_list) | ||||
@@ -52,10 +52,6 @@ namespace Tensorflow | |||||
public void _set_control_flow_context(ControlFlowContext ctx) | public void _set_control_flow_context(ControlFlowContext ctx) | ||||
{ | { | ||||
if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc")) | |||||
{ | |||||
} | |||||
_control_flow_context = ctx; | _control_flow_context = ctx; | ||||
} | } | ||||
@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using NumSharp; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
@@ -170,7 +171,7 @@ namespace Tensorflow | |||||
{ | { | ||||
if (string.IsNullOrEmpty(latest_filename)) | if (string.IsNullOrEmpty(latest_filename)) | ||||
latest_filename = "checkpoint"; | latest_filename = "checkpoint"; | ||||
object model_checkpoint_path = ""; | |||||
NDArray[] model_checkpoint_path = null; | |||||
string checkpoint_file = ""; | string checkpoint_file = ""; | ||||
if (global_step > 0) | if (global_step > 0) | ||||
@@ -183,15 +184,14 @@ namespace Tensorflow | |||||
if (!_is_empty) | if (!_is_empty) | ||||
{ | { | ||||
model_checkpoint_path = sess.run(_saver_def.SaveTensorName, | model_checkpoint_path = sess.run(_saver_def.SaveTensorName, | ||||
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | |||||
); | |||||
(_saver_def.FilenameTensorName, checkpoint_file)); | |||||
if (write_state) | if (write_state) | ||||
{ | { | ||||
_RecordLastCheckpoint(model_checkpoint_path.ToString()); | |||||
_RecordLastCheckpoint(model_checkpoint_path[0].ToString()); | |||||
checkpoint_management.update_checkpoint_state_internal( | checkpoint_management.update_checkpoint_state_internal( | ||||
save_dir: save_path_parent, | save_dir: save_path_parent, | ||||
model_checkpoint_path: model_checkpoint_path.ToString(), | |||||
model_checkpoint_path: model_checkpoint_path[0].ToString(), | |||||
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(), | ||||
latest_filename: latest_filename, | latest_filename: latest_filename, | ||||
save_relative_paths: _save_relative_paths); | save_relative_paths: _save_relative_paths); | ||||
@@ -44,8 +44,7 @@ namespace Tensorflow | |||||
float? last_preserved_timestamp = null | float? last_preserved_timestamp = null | ||||
) | ) | ||||
{ | { | ||||
CheckpointState ckpt = null; | |||||
CheckpointState ckpt = null; | |||||
// Writes the "checkpoint" file for the coordinator for later restoration. | // Writes the "checkpoint" file for the coordinator for later restoration. | ||||
string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); | string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename); | ||||
if (save_relative_paths) | if (save_relative_paths) | ||||
@@ -65,7 +64,12 @@ namespace Tensorflow | |||||
throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + | throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " + | ||||
"checkpoint state. Please use a different save path."); | "checkpoint state. Please use a different save path."); | ||||
File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); | |||||
// File.WriteAllText(coord_checkpoint_filename, ckpt.ToString()); | |||||
File.WriteAllLines(coord_checkpoint_filename, new[] | |||||
{ | |||||
$"model_checkpoint_path: \"{ckpt.ModelCheckpointPath}\"", | |||||
$"all_model_checkpoint_paths: \"{ckpt.AllModelCheckpointPaths[0]}\"", | |||||
}); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -98,7 +102,14 @@ namespace Tensorflow | |||||
all_model_checkpoint_paths.Add(model_checkpoint_path); | all_model_checkpoint_paths.Add(model_checkpoint_path); | ||||
// Relative paths need to be rewritten to be relative to the "save_dir" | // Relative paths need to be rewritten to be relative to the "save_dir" | ||||
// if model_checkpoint_path already contains "save_dir". | |||||
if (model_checkpoint_path.StartsWith(save_dir)) | |||||
{ | |||||
model_checkpoint_path = model_checkpoint_path.Substring(save_dir.Length + 1); | |||||
all_model_checkpoint_paths = all_model_checkpoint_paths | |||||
.Select(x => x.Substring(save_dir.Length + 1)) | |||||
.ToList(); | |||||
} | |||||
var coord_checkpoint_proto = new CheckpointState() | var coord_checkpoint_proto = new CheckpointState() | ||||
{ | { | ||||
@@ -29,14 +29,12 @@ namespace Tensorflow | |||||
{ | { | ||||
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | ||||
var meta = meta_graph.import_scoped_meta_graph_with_return_elements( | |||||
var (imported_vars, imported_return_elements) = meta_graph.import_scoped_meta_graph_with_return_elements( | |||||
meta_graph_def, | meta_graph_def, | ||||
clear_devices: clear_devices, | clear_devices: clear_devices, | ||||
import_scope: import_scope, | import_scope: import_scope, | ||||
return_elements: return_elements); | return_elements: return_elements); | ||||
var (imported_vars, imported_return_elements) = meta; | |||||
var saver = _create_saver_from_imported_meta_graph( | var saver = _create_saver_from_imported_meta_graph( | ||||
meta_graph_def, import_scope, imported_vars); | meta_graph_def, import_scope, imported_vars); | ||||