Oceania2018 6 years ago
parent
commit
0595ea1b0f
8 changed files with 26 additions and 12 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  2. +10
    -0
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  3. +9
    -8
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Train/Saving/saver.py.cs
  5. +1
    -1
      test/KerasNET.Test/Keras.UnitTest.csproj
  6. +2
    -0
      test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  8. +1
    -1
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
<PackageReference Include="NumSharp" Version="0.10.2" />
<PackageReference Include="NumSharp" Version="0.10.3" />
</ItemGroup>

<ItemGroup>


+ 10
- 0
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -16,6 +16,12 @@ namespace Tensorflow
_write_version = write_version;
}

/// <summary>
/// Create an Op to save 'saveables'.
/// </summary>
/// <param name="filename_tensor"></param>
/// <param name="saveables"></param>
/// <returns></returns>
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
{
var tensor_names = new List<string>();
@@ -105,6 +111,10 @@ namespace Tensorflow
}

var graph = ops.get_default_graph();
// Do some sanity checking on collections containing
// PartitionedVariables. If a saved collection has a PartitionedVariable,
// the GraphDef needs to include concat ops to get the value (or there'll
// be a lookup error on load).
var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list)
{


+ 9
- 8
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -158,7 +158,10 @@ namespace Tensorflow
string model_checkpoint_path = "";
string checkpoint_file = "";

checkpoint_file = $"{save_path}-{global_step}";
if (global_step > 0)
checkpoint_file = $"{save_path}-{global_step}";
else
checkpoint_file = save_path;

var save_path_parent = Path.GetDirectoryName(save_path);

@@ -291,15 +294,13 @@ namespace Tensorflow
if (_saver_def.MaxToKeep <= 0) return;

// Remove first from list if the same name was used before.
foreach (var p in _last_checkpoints)
if (latest_save_path == _CheckpointFilename((p.Key, p.Value)))
_last_checkpoints.Remove(p.Key);

// Append new path to list
_last_checkpoints.Add(latest_save_path, Python.time());
var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value)));
if (_existed_checkpoints.Key != null)
_last_checkpoints.Remove(_existed_checkpoints.Key);
_last_checkpoints.Add(latest_save_path, time());

// If more than max_to_keep, remove oldest.
if(_last_checkpoints.Count > _saver_def.MaxToKeep)
if (_last_checkpoints.Count > _saver_def.MaxToKeep)
{
var first = _last_checkpoints.First();
_last_checkpoints.Remove(first.Key);


+ 1
- 1
src/TensorFlowNET.Core/Train/Saving/saver.py.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow
var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars);

return (saver, null);
return (saver, imported_return_elements);
}

/// <summary>


+ 1
- 1
test/KerasNET.Test/Keras.UnitTest.csproj View File

@@ -26,7 +26,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
</ItemGroup>


+ 2
- 0
test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs View File

@@ -105,6 +105,8 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();
train_saver.save(sess, CHECKPOINT_NAME);

sw.Restart();

for (int i = 0; i < how_many_training_steps; i++)


+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -17,6 +17,7 @@
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
<PackageReference Include="SharpZipLib" Version="1.1.0" />
<PackageReference Include="System.Drawing.Common" Version="4.5.1" />
<PackageReference Include="TensorFlow.NET" Version="0.8.0" />
</ItemGroup>

<ItemGroup>


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -16,7 +16,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
</ItemGroup>


Loading…
Cancel
Save