Browse Source

add Prediction function for Transfer Learning

tags/v0.12
Oceania2018 6 years ago
parent
commit
8f9f5e36dd
4 changed files with 49 additions and 15 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.graph.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  3. +7
    -7
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +38
    -4
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.graph.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
return ops.get_default_graph();
}

public static Graph Graph() => new Graph();
public static Graph Graph()
=> new Graph();
}
}

+ 2
- 2
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -55,11 +55,11 @@ namespace Tensorflow
return Status;
}

public static Graph ImportFromPB(string file_path)
public static Graph ImportFromPB(string file_path, string name = null)
{
var graph = tf.Graph().as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
importer.import_graph_def(graph_def);
importer.import_graph_def(graph_def, name: name);
return graph;
}
}


+ 7
- 7
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.10.2</Version>
<Version>0.10.3</Version>
<Authors>Haiping Chen, Meinrad Recheis</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,9 +16,8 @@
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io
Medium: https://medium.com/scisharp</Description>
<AssemblyVersion>0.10.2.0</AssemblyVersion>
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.10.3.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.9.0:

1. Added full connected Convolution Neural Network example.
@@ -29,11 +28,12 @@ Medium: https://medium.com/scisharp</Description>
6. Add StridedSliceGrad.
7. Add BatchMatMulGrad.
8. Upgrade NumSharp.
9. Fix strided_slice_grad type convention error.</PackageReleaseNotes>
9. Fix strided_slice_grad type convention error.
10. Add AbsGrad.</PackageReleaseNotes>
<LangVersion>7.2</LangVersion>
<FileVersion>0.10.2.0</FileVersion>
<FileVersion>0.10.3.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>false</PackageRequireLicenseAcceptance>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
</PropertyGroup>


+ 38
- 4
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -80,6 +80,16 @@ namespace TensorFlowNET.Examples.ImageProcess
{
PrepareData();

#region For debug purpose
// predict images
Predict(null);

// load saved pb and test new images.
Test(null);
#endregion

var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

with(tf.Session(graph), sess =>
@@ -708,14 +718,38 @@ namespace TensorFlowNET.Examples.ImageProcess
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
}

public void Predict(Session sess)
public void Predict(Session sess_)
{
throw new NotImplementedException();
if (!File.Exists(output_graph))
return;

var graph = Graph.ImportFromPB(output_graph, "");

Tensor input_layer = graph.OperationByName("input/BottleneckInputPlaceholder");
Tensor output_layer = graph.OperationByName("final_result");

with(tf.Session(graph), sess =>
{
// load images into NDArray in a matrix[image_num, features]
var nd = np.arange(2048f).reshape(1, 2048); // replace this line
var result = sess.run(output_layer, new FeedItem(input_layer, nd));
});
}

public void Test(Session sess)
public void Test(Session sess_)
{
throw new NotImplementedException();
if (!File.Exists(output_graph))
return;

var graph = Graph.ImportFromPB(output_graph);
var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

with(tf.Session(graph), sess =>
{
(test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
bottleneck_tensor);
});
}
}
}

Loading…
Cancel
Save