diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 1e532dc8..cecdbd38 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io
-
+
diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
index f0db80ce..13886401 100644
--- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
@@ -16,6 +16,12 @@ namespace Tensorflow
_write_version = write_version;
}
+ ///
+ /// Create an Op to save 'saveables'.
+ ///
+ ///
+ ///
+ ///
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
{
var tensor_names = new List();
@@ -105,6 +111,10 @@ namespace Tensorflow
}
var graph = ops.get_default_graph();
+ // Do some sanity checking on collections containing
+ // PartitionedVariables. If a saved collection has a PartitionedVariable,
+ // the GraphDef needs to include concat ops to get the value (or there'll
+ // be a lookup error on load).
var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list)
{
diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
index 0f3a2ab8..b367eb1d 100644
--- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
@@ -158,7 +158,10 @@ namespace Tensorflow
string model_checkpoint_path = "";
string checkpoint_file = "";
- checkpoint_file = $"{save_path}-{global_step}";
+ if (global_step > 0)
+ checkpoint_file = $"{save_path}-{global_step}";
+ else
+ checkpoint_file = save_path;
var save_path_parent = Path.GetDirectoryName(save_path);
@@ -291,15 +294,13 @@ namespace Tensorflow
if (_saver_def.MaxToKeep <= 0) return;
// Remove first from list if the same name was used before.
- foreach (var p in _last_checkpoints)
- if (latest_save_path == _CheckpointFilename((p.Key, p.Value)))
- _last_checkpoints.Remove(p.Key);
-
- // Append new path to list
- _last_checkpoints.Add(latest_save_path, Python.time());
+ var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value)));
+ if (_existed_checkpoints.Key != null)
+ _last_checkpoints.Remove(_existed_checkpoints.Key);
+ _last_checkpoints.Add(latest_save_path, time());
// If more than max_to_keep, remove oldest.
- if(_last_checkpoints.Count > _saver_def.MaxToKeep)
+ if (_last_checkpoints.Count > _saver_def.MaxToKeep)
{
var first = _last_checkpoints.First();
_last_checkpoints.Remove(first.Key);
diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
index 303f41a4..c21b954f 100644
--- a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
@@ -25,7 +25,7 @@ namespace Tensorflow
var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars);
- return (saver, null);
+ return (saver, imported_return_elements);
}
///
diff --git a/test/KerasNET.Test/Keras.UnitTest.csproj b/test/KerasNET.Test/Keras.UnitTest.csproj
index 89a5425c..1e9a2253 100644
--- a/test/KerasNET.Test/Keras.UnitTest.csproj
+++ b/test/KerasNET.Test/Keras.UnitTest.csproj
@@ -26,7 +26,7 @@
-
+
diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
index 76fd7106..3793027a 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
@@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();
+ train_saver.save(sess, CHECKPOINT_NAME);
+
sw.Restart();
for (int i = 0; i < how_many_training_steps; i++)
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index f6cbea99..9f4e473d 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -17,6 +17,7 @@
+
diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
index 5e619165..f76ca132 100644
--- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
@@ -16,7 +16,7 @@
-
+