Oceania2018 6 years ago
parent
commit
0595ea1b0f
8 changed files with 26 additions and 12 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  2. +10
    -0
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  3. +9
    -8
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Train/Saving/saver.py.cs
  5. +1
    -1
      test/KerasNET.Test/Keras.UnitTest.csproj
  6. +2
    -0
      test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  8. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>


<ItemGroup> <ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.8.0" /> <PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="NumSharp" Version="0.10.2" />
<PackageReference Include="NumSharp" Version="0.10.3" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 10
- 0
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -16,6 +16,12 @@ namespace Tensorflow
_write_version = write_version; _write_version = write_version;
} }


/// <summary>
/// Create an Op to save 'saveables'.
/// </summary>
/// <param name="filename_tensor"></param>
/// <param name="saveables"></param>
/// <returns></returns>
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables) public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
{ {
var tensor_names = new List<string>(); var tensor_names = new List<string>();
@@ -105,6 +111,10 @@ namespace Tensorflow
} }


var graph = ops.get_default_graph(); var graph = ops.get_default_graph();
// Do some sanity checking on collections containing
// PartitionedVariables. If a saved collection has a PartitionedVariable,
// the GraphDef needs to include concat ops to get the value (or there'll
// be a lookup error on load).
var check_collection_list = graph.get_all_collection_keys(); var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list) foreach (var collection_type in check_collection_list)
{ {


+ 9
- 8
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -158,7 +158,10 @@ namespace Tensorflow
string model_checkpoint_path = ""; string model_checkpoint_path = "";
string checkpoint_file = ""; string checkpoint_file = "";


checkpoint_file = $"{save_path}-{global_step}";
if (global_step > 0)
checkpoint_file = $"{save_path}-{global_step}";
else
checkpoint_file = save_path;


var save_path_parent = Path.GetDirectoryName(save_path); var save_path_parent = Path.GetDirectoryName(save_path);


@@ -291,15 +294,13 @@ namespace Tensorflow
if (_saver_def.MaxToKeep <= 0) return; if (_saver_def.MaxToKeep <= 0) return;


// Remove first from list if the same name was used before. // Remove first from list if the same name was used before.
foreach (var p in _last_checkpoints)
if (latest_save_path == _CheckpointFilename((p.Key, p.Value)))
_last_checkpoints.Remove(p.Key);

// Append new path to list
_last_checkpoints.Add(latest_save_path, Python.time());
var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value)));
if (_existed_checkpoints.Key != null)
_last_checkpoints.Remove(_existed_checkpoints.Key);
_last_checkpoints.Add(latest_save_path, time());


// If more than max_to_keep, remove oldest. // If more than max_to_keep, remove oldest.
if(_last_checkpoints.Count > _saver_def.MaxToKeep)
if (_last_checkpoints.Count > _saver_def.MaxToKeep)
{ {
var first = _last_checkpoints.First(); var first = _last_checkpoints.First();
_last_checkpoints.Remove(first.Key); _last_checkpoints.Remove(first.Key);


+ 1
- 1
src/TensorFlowNET.Core/Train/Saving/saver.py.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow
var saver = _create_saver_from_imported_meta_graph( var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars); meta_graph_def, import_scope, imported_vars);


return (saver, null);
return (saver, imported_return_elements);
} }


/// <summary> /// <summary>


+ 1
- 1
test/KerasNET.Test/Keras.UnitTest.csproj View File

@@ -26,7 +26,7 @@
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
</ItemGroup> </ItemGroup>


+ 2
- 0
test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs View File

@@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph // Create a train saver that is used to restore values into an eval graph
// when exporting models. // when exporting models.
var train_saver = tf.train.Saver(); var train_saver = tf.train.Saver();
train_saver.save(sess, CHECKPOINT_NAME);

sw.Restart(); sw.Restart();


for (int i = 0; i < how_many_training_steps; i++) for (int i = 0; i < how_many_training_steps; i++)


+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -17,6 +17,7 @@
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> <PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="SharpZipLib" Version="1.1.0" /> <PackageReference Include="SharpZipLib" Version="1.1.0" />
<PackageReference Include="System.Drawing.Common" Version="4.5.1" /> <PackageReference Include="System.Drawing.Common" Version="4.5.1" />
<PackageReference Include="TensorFlow.NET" Version="0.8.0" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -16,7 +16,7 @@
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
</ItemGroup> </ItemGroup>


Loading…
Cancel
Save