From b345efa320416fdba7183b5c13446d643417b6a8 Mon Sep 17 00:00:00 2001 From: c q Date: Sun, 28 Apr 2019 10:01:17 +0800 Subject: [PATCH] Solve ObjectDetection functions --- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 6 ++++++ src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs | 3 ++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 1 + test/TensorFlowNET.Examples/ObjectDetection.cs | 5 ++--- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3c0c7486..f331f6a3 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -217,6 +217,12 @@ namespace Tensorflow var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); nd = np.array(str).reshape(); break; + case TF_DataType.TF_UINT8: + var _bytes = new byte[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(_bytes).reshape(ndims); + break; case TF_DataType.TF_INT16: var shorts = new short[tensor.size]; for (ulong i = 0; i < tensor.size; i++) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 5f26becd..3b8b65dd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -67,7 +67,8 @@ namespace Tensorflow case "Double": Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; - //case "Byte": + case "Byte": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); /*var bb = nd.Data(); var bytes = Marshal.AllocHGlobal(bb.Length); Marshal.Copy(bb, 0, bytes, bb.Length); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 0c2e84d5..46a0e264 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -195,6 +195,7 @@ namespace Tensorflow case "Double": return TF_DataType.TF_DOUBLE; case "Byte": + return TF_DataType.TF_UINT8; case "String": return TF_DataType.TF_STRING; default: diff --git a/test/TensorFlowNET.Examples/ObjectDetection.cs b/test/TensorFlowNET.Examples/ObjectDetection.cs index ae0e87c6..44662218 100644 --- a/test/TensorFlowNET.Examples/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ObjectDetection.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples public bool Run() { - //buildOutputImage(null); + PrepareData(); // read in the input image var imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg")); @@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples // download the pbtxt file if (!File.Exists(Path.Join(modelDir, "mscoco_label_map.pbtxt"))) { - string url = $"https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_label_map.pbtxt"; + string url = $"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt"; Utility.Web.Download(url, modelDir, "mscoco_label_map.pbtxt"); } } @@ -123,7 +123,6 @@ namespace TensorFlowNET.Examples float score = scores[i]; if (score > MIN_SCORE) { - //var boxes = resultArr[1].Data(); float[] boxes = resultArr[1].Data(); float top = boxes[i * 4] * bitmap.Height; float left = boxes[i * 4 + 1] * bitmap.Width;