@@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.8.0" /> | <PackageReference Include="Google.Protobuf" Version="3.8.0" /> | ||||
<PackageReference Include="NumSharp" Version="0.10.2" /> | |||||
<PackageReference Include="NumSharp" Version="0.10.3" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -16,6 +16,12 @@ namespace Tensorflow | |||||
_write_version = write_version; | _write_version = write_version; | ||||
} | } | ||||
/// <summary> | |||||
/// Create an Op to save 'saveables'. | |||||
/// </summary> | |||||
/// <param name="filename_tensor"></param> | |||||
/// <param name="saveables"></param> | |||||
/// <returns></returns> | |||||
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) | ||||
{ | { | ||||
var tensor_names = new List<string>(); | var tensor_names = new List<string>(); | ||||
@@ -105,6 +111,10 @@ namespace Tensorflow | |||||
} | } | ||||
var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
// Do some sanity checking on collections containing | |||||
// PartitionedVariables. If a saved collection has a PartitionedVariable, | |||||
// the GraphDef needs to include concat ops to get the value (or there'll | |||||
// be a lookup error on load). | |||||
var check_collection_list = graph.get_all_collection_keys(); | var check_collection_list = graph.get_all_collection_keys(); | ||||
foreach (var collection_type in check_collection_list) | foreach (var collection_type in check_collection_list) | ||||
{ | { | ||||
@@ -158,7 +158,10 @@ namespace Tensorflow | |||||
string model_checkpoint_path = ""; | string model_checkpoint_path = ""; | ||||
string checkpoint_file = ""; | string checkpoint_file = ""; | ||||
checkpoint_file = $"{save_path}-{global_step}"; | |||||
if (global_step > 0) | |||||
checkpoint_file = $"{save_path}-{global_step}"; | |||||
else | |||||
checkpoint_file = save_path; | |||||
var save_path_parent = Path.GetDirectoryName(save_path); | var save_path_parent = Path.GetDirectoryName(save_path); | ||||
@@ -291,15 +294,13 @@ namespace Tensorflow | |||||
if (_saver_def.MaxToKeep <= 0) return; | if (_saver_def.MaxToKeep <= 0) return; | ||||
// Remove first from list if the same name was used before. | // Remove first from list if the same name was used before. | ||||
foreach (var p in _last_checkpoints) | |||||
if (latest_save_path == _CheckpointFilename((p.Key, p.Value))) | |||||
_last_checkpoints.Remove(p.Key); | |||||
// Append new path to list | |||||
_last_checkpoints.Add(latest_save_path, Python.time()); | |||||
var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value))); | |||||
if (_existed_checkpoints.Key != null) | |||||
_last_checkpoints.Remove(_existed_checkpoints.Key); | |||||
_last_checkpoints.Add(latest_save_path, time()); | |||||
// If more than max_to_keep, remove oldest. | // If more than max_to_keep, remove oldest. | ||||
if(_last_checkpoints.Count > _saver_def.MaxToKeep) | |||||
if (_last_checkpoints.Count > _saver_def.MaxToKeep) | |||||
{ | { | ||||
var first = _last_checkpoints.First(); | var first = _last_checkpoints.First(); | ||||
_last_checkpoints.Remove(first.Key); | _last_checkpoints.Remove(first.Key); | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow | |||||
var saver = _create_saver_from_imported_meta_graph( | var saver = _create_saver_from_imported_meta_graph( | ||||
meta_graph_def, import_scope, imported_vars); | meta_graph_def, import_scope, imported_vars); | ||||
return (saver, null); | |||||
return (saver, imported_return_elements); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -26,7 +26,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
// Create a train saver that is used to restore values into an eval graph | // Create a train saver that is used to restore values into an eval graph | ||||
// when exporting models. | // when exporting models. | ||||
var train_saver = tf.train.Saver(); | var train_saver = tf.train.Saver(); | ||||
train_saver.save(sess, CHECKPOINT_NAME); | |||||
sw.Restart(); | sw.Restart(); | ||||
for (int i = 0; i < how_many_training_steps; i++) | for (int i = 0; i < how_many_training_steps; i++) | ||||
@@ -17,6 +17,7 @@ | |||||
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | ||||
<PackageReference Include="SharpZipLib" Version="1.1.0" /> | <PackageReference Include="SharpZipLib" Version="1.1.0" /> | ||||
<PackageReference Include="System.Drawing.Common" Version="4.5.1" /> | <PackageReference Include="System.Drawing.Common" Version="4.5.1" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.8.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -16,7 +16,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||