Browse Source

fix freeze_graph output_node_names

tags/v0.20
Oceania2018 5 years ago
parent
commit
21cf2be660
3 changed files with 9 additions and 7 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +3
    -3
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  3. +4
    -2
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -56,8 +56,8 @@ namespace Tensorflow
public Graph load_graph(string freeze_graph_pb)
=> saver.load_graph(freeze_graph_pb);

public string freeze_graph(string checkpoint_dir, string output_pb_name)
=> saver.freeze_graph(checkpoint_dir, output_pb_name);
public string freeze_graph(string checkpoint_dir, string output_pb_name, string[] output_node_names)
=> saver.freeze_graph(checkpoint_dir, output_pb_name, output_node_names);

public Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,


+ 3
- 3
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.1</TargetTensorFlow>
<Version>0.14.1</Version>
<Version>0.14.1.1</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -18,12 +18,12 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.14.1.0</AssemblyVersion>
<AssemblyVersion>0.14.1.1</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.14.0:
1: Add TransformGraphWithStringInputs.
2: tf.trainer.load_graph, tf.trainer.freeze_graph</PackageReleaseNotes>
<LangVersion>7.3</LangVersion>
<FileVersion>0.14.1.0</FileVersion>
<FileVersion>0.14.1.1</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 4
- 2
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -85,7 +85,9 @@ namespace Tensorflow
}
}

public static string freeze_graph(string checkpoint_dir, string output_pb_name)
public static string freeze_graph(string checkpoint_dir,
string output_pb_name,
string[] output_node_names)
{
var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir);
if (!File.Exists($"{checkpoint}.meta")) return null;
@@ -99,7 +101,7 @@ namespace Tensorflow
saver.restore(sess, checkpoint);
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
new string[] { "output/ArgMax" });
output_node_names);
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;


Loading…
Cancel
Save