Browse Source

Solve ObjectDetection functions

tags/v0.9
c q 6 years ago
parent
commit
b345efa320
4 changed files with 11 additions and 4 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +2
    -3
      test/TensorFlowNET.Examples/ObjectDetection.cs

+ 6
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -217,6 +217,12 @@ namespace Tensorflow
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
nd = np.array(_bytes).reshape(ndims);
break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)


+ 2
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -67,7 +67,8 @@ namespace Tensorflow
case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);
break;
//case "Byte":
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
/*var bb = nd.Data<byte>();
var bytes = Marshal.AllocHGlobal(bb.Length);
Marshal.Copy(bb, 0, bytes, bb.Length);


+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -195,6 +195,7 @@ namespace Tensorflow
case "Double":
return TF_DataType.TF_DOUBLE;
case "Byte":
return TF_DataType.TF_UINT8;
case "String":
return TF_DataType.TF_STRING;
default:


+ 2
- 3
test/TensorFlowNET.Examples/ObjectDetection.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples

public bool Run()
{
//buildOutputImage(null);
PrepareData();

// read in the input image
var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg"));
@@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples
// download the pbtxt file
if (!File.Exists(Path.Join(modelDir, "mscoco_label_map.pbtxt")))
{
string url = $"https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_label_map.pbtxt";
string url = $"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
Utility.Web.Download(url, modelDir, "mscoco_label_map.pbtxt");
}
}
@@ -123,7 +123,6 @@ namespace TensorFlowNET.Examples
float score = scores[i];
if (score > MIN_SCORE)
{
//var boxes = resultArr[1].Data<float[,,]>();
float[] boxes = resultArr[1].Data<float>();
float top = boxes[i * 4] * bitmap.Height;
float left = boxes[i * 4 + 1] * bitmap.Width;


Loading…
Cancel
Save