Browse Source

fix SessionTest.Session #247

tags/v0.9
Oceania2018 6 years ago
parent
commit
ddbbe068e7
2 changed files with 16 additions and 12 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +15
    -11
      test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs

+ 1
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow

public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
{
_graph = graph is null ? ops.get_default_graph() : g;
_graph = g is null ? ops.get_default_graph() : g;

_target = UTF8Encoding.UTF8.GetBytes(target);



+ 15
- 11
test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs View File

@@ -29,14 +29,12 @@ namespace TensorFlowNET.Examples
string dir = "ImageRecognitionInception";
string pbFile = "tensorflow_inception_graph.pb";
string labelFile = "imagenet_comp_graph_label_strings.txt";
List<NDArray> file_ndarrays = new List<NDArray>();

public bool Run()
{
PrepareData();

var labels = File.ReadAllLines(Path.Join(dir, labelFile));
var files = Directory.GetFiles(Path.Join(dir, "img"));

var graph = new Graph().as_default();
//import GraphDef from pb file
graph.Import(Path.Join(dir, pbFile));
@@ -47,23 +45,21 @@ namespace TensorFlowNET.Examples
var input_operation = graph.OperationByName(input_name);
var output_operation = graph.OperationByName(output_name);

var labels = File.ReadAllLines(Path.Join(dir, labelFile));
var result_labels = new List<string>();
var sw = new Stopwatch();

with(tf.Session(graph), sess =>
{
foreach (var file in files)
foreach (var nd in file_ndarrays)
{
sw.Restart();

// load image file
var tensor = ReadTensorFromImageFile(file);
var results = sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], tensor));
var results = sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], nd));
results = np.squeeze(results);
int idx = np.argmax(results);

Console.WriteLine($"{file.Split(Path.DirectorySeparatorChar).Last()}: {labels[idx]} {results[idx]} in {sw.ElapsedMilliseconds}ms", Color.Tan);

Console.WriteLine($"{labels[idx]} {results[idx]} in {sw.ElapsedMilliseconds}ms", Color.Tan);
result_labels.Add(labels[idx]);
}
});
@@ -77,7 +73,7 @@ namespace TensorFlowNET.Examples
int input_mean = 117,
int input_std = 1)
{
return with(tf.Graph(), graph =>
return with(tf.Graph().as_default(), graph =>
{
var file_reader = tf.read_file(file_name, "file_reader");
var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg");
@@ -110,6 +106,14 @@ namespace TensorFlowNET.Examples

url = $"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/shasta-daisy.jpg";
Utility.Web.Download(url, Path.Join(dir, "img"), "shasta-daisy.jpg");

// load image file
var files = Directory.GetFiles(Path.Join(dir, "img"));
for (int i = 0; i < files.Length; i++)
{
var nd = ReadTensorFromImageFile(files[i]);
file_ndarrays.Add(nd);
}
}
}
}

Loading…
Cancel
Save