Browse Source

Add object detection example

Add byte type support in Session and Tensor
Add pbtxtParser class to utility
tags/v0.9
c q 6 years ago
parent
commit
f8324e87e5
7 changed files with 266 additions and 2 deletions
  1. +6
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +9
    -0
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  5. +173
    -0
      test/TensorFlowNET.Examples/ObjectDetection.cs
  6. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  7. +74
    -0
      test/TensorFlowNET.Examples/Utility/PbtxtParser.cs

+ 6
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -25,7 +25,6 @@ namespace Tensorflow
else
{
_graph = graph;

}

_target = UTF8Encoding.UTF8.GetBytes(target);
@@ -212,6 +211,12 @@ namespace Tensorflow
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)


+ 9
- 0
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -30,6 +30,15 @@ namespace Tensorflow
case "Single":
nd.SetData(values.Select(x => (float)x).ToArray());
break;
case "NDArray":
// nd.SetData<NDArray>(values.ToArray());
//NDArray[] arr = new NDArray[values.Count];
//for (int i=0; i<values.Count; i++)
NDArray[] arr = values.Select(x => (NDArray)x).ToArray();
nd = new NDArray(arr);
break;
default:
break;
}

return nd;


+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -67,7 +67,8 @@ namespace Tensorflow
case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
break;
//case "Byte":
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
/*var bb = nd.Data<byte>();
var bytes = Marshal.AllocHGlobal(bb.Length);
Marshal.Copy(bb, 0, bytes, bb.Length);


+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -195,6 +195,7 @@ namespace Tensorflow
case "Double":
return TF_DataType.TF_DOUBLE;
case "Byte":
return TF_DataType.TF_UINT8;
case "String":
return TF_DataType.TF_STRING;
default:


+ 173
- 0
test/TensorFlowNET.Examples/ObjectDetection.cs View File

@@ -0,0 +1,173 @@
using Newtonsoft.Json;
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using System.Drawing;
using System.Drawing.Drawing2D;
using System.Linq;

namespace TensorFlowNET.Examples
{

public class ObjectDetection : Python, IExample
{
public int Priority => 7;
public bool Enabled { get; set; } = true;
public string Name => "Image Recognition";
public float MIN_SCORE = 0.5f;

string modelDir = "ssd_mobilenet_v1_coco_2018_01_28";
string imageDir = "images";
string pbFile = "frozen_inference_graph.pb";
string labelFile = "mscoco_label_map.pbtxt";
string picFile = "input.jpg";

public bool Run()
{
//buildOutputImage(null);

// read in the input image
var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg"));

var graph = new Graph().as_default();
graph.Import(Path.Join(modelDir, pbFile));

var tensorNum = graph.OperationByName("num_detections").outputs[0];
var tensorBoxes = graph.OperationByName("detection_boxes").outputs[0];
var tensorScores = graph.OperationByName("detection_scores").outputs[0];
var tensorClasses = graph.OperationByName("detection_classes").outputs[0];

var imgTensor = graph.OperationByName("image_tensor").outputs[0];


Tensor[] outTensorArr = new Tensor[] { tensorNum, tensorBoxes, tensorScores, tensorClasses };

with(tf.Session(graph), sess =>
{
var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr));

//NDArray scores = results.Array.GetValue(2) as NDArray;

//floatscores.Data<float>();
NDArray[] resultArr = results.Data<NDArray>();

//float[] scores = resultArr[2].Data<float>();
buildOutputImage(resultArr);
});

return true;
}

public void PrepareData()
{
if (!Directory.Exists(modelDir))
Directory.CreateDirectory(modelDir);

if (!File.Exists(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz")))
{
// get model file
string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz";

Utility.Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz");
}

if (!File.Exists(Path.Join(modelDir, "frozen_inference_graph.pb")))
{
Utility.Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./");
}


// download sample picture
if (!Directory.Exists(imageDir))
Directory.CreateDirectory(imageDir);

if (!File.Exists(Path.Join(imageDir, "input.jpg")))
{
string url = $"https://github.com/tensorflow/models/raw/master/research/object_detection/test_images/image2.jpg";
Utility.Web.Download(url, imageDir, "input.jpg");
}

// download the pbtxt file
if (!File.Exists(Path.Join(modelDir, "mscoco_label_map.pbtxt")))
{
string url = $"https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_label_map.pbtxt";
Utility.Web.Download(url, modelDir, "mscoco_label_map.pbtxt");
}
}

private NDArray ReadTensorFromImageFile(string file_name)
{
return with(tf.Graph().as_default(), graph =>
{
var file_reader = tf.read_file(file_name, "file_reader");
var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg");
var casted = tf.cast(decodeJpeg, TF_DataType.TF_UINT8);
var dims_expander = tf.expand_dims(casted, 0);
return with(tf.Session(graph), sess => sess.run(dims_expander));
});
}

private void buildOutputImage(NDArray[] resultArr)
{
// get pbtxt items
PbtxtItems pbTxtItems = PbtxtParser.ParsePbtxtFile(Path.Join(modelDir, "mscoco_label_map.pbtxt"));

// get bitmap
Bitmap bitmap = new Bitmap(Path.Join(imageDir, "input.jpg"));

float[] scores = resultArr[2].Data<float>();

for (int i=0; i<scores.Length; i++)
{
float score = scores[i];
if (score > MIN_SCORE)
{
//var boxes = resultArr[1].Data<float[,,]>();
float[] boxes = resultArr[1].Data<float>();
float top = boxes[i * 4] * bitmap.Height;
float left = boxes[i * 4 + 1] * bitmap.Width;
float bottom = boxes[i * 4 + 2] * bitmap.Height;
float right = boxes[i * 4 + 3] * bitmap.Width;

Rectangle rect = new Rectangle()
{
X = (int)left,
Y = (int)top,
Width = (int)(right - left),
Height = (int)(bottom - top)
};

float[] ids = resultArr[3].Data<float>();

string name = pbTxtItems.items.Where(w => w.id == (int)ids[i]).Select(s=>s.display_name).FirstOrDefault();

drawObjectOnBitmap(bitmap, rect, score, name);
}
}

bitmap.Save(Path.Join(imageDir, "output.jpg"));
}

private void drawObjectOnBitmap(Bitmap bmp, Rectangle rect, float score, string name)
{
using (Graphics graphic = Graphics.FromImage(bmp))
{
graphic.SmoothingMode = SmoothingMode.AntiAlias;
using (Pen pen = new Pen(Color.Red, 2))
{
graphic.DrawRectangle(pen, rect);

Point p = new Point(rect.Right + 5, rect.Top + 5);
string text = string.Format("{0}:{1}%", name, (int)(score * 100));
graphic.DrawString(text, new Font("Verdana", 8), Brushes.Red, p);
}
}
}
}
}

+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,6 +10,7 @@
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="SharpZipLib" Version="1.1.0" />
<PackageReference Include="System.Drawing.Common" Version="4.5.1" />
</ItemGroup>

<ItemGroup>


+ 74
- 0
test/TensorFlowNET.Examples/Utility/PbtxtParser.cs View File

@@ -0,0 +1,74 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace TensorFlowNET.Examples.Utility
{
public class PbtxtItem
{
public string name { get; set; }
public int id { get; set; }
public string display_name { get; set; }
}
public class PbtxtItems
{
public List<PbtxtItem> items { get; set; }
}

public class PbtxtParser
{
public static PbtxtItems ParsePbtxtFile(string filePath)
{
string line;
string newText = "{\"items\":[";

try
{
using (System.IO.StreamReader reader = new System.IO.StreamReader(filePath))
{

while ((line = reader.ReadLine()) != null)
{
string newline = string.Empty;

if (line.Contains("{"))
{
newline = line.Replace("item", "").Trim();
//newText += line.Insert(line.IndexOf("=") + 1, "\"") + "\",";
newText += newline;
}
else if (line.Contains("}"))
{
newText = newText.Remove(newText.Length - 1);
newText += line;
newText += ",";
}
else
{
newline = line.Replace(":", "\":").Trim();
newline = "\"" + newline;// newline.Insert(0, "\"");
newline += ",";

newText += newline;
}

}

newText = newText.Remove(newText.Length - 1);
newText += "]}";

reader.Close();
}

PbtxtItems items = JsonConvert.DeserializeObject<PbtxtItems>(newText);

return items;
}
catch (Exception ex)
{
return null;
}
}
}
}

Loading…
Cancel
Save